Skip to content

Commit 7751e1e

Browse files
committed
Rate limit apply for gateway connections
1 parent f28bac1 commit 7751e1e

File tree

8 files changed

+268
-41
lines changed

8 files changed

+268
-41
lines changed

platform-api/src/config/config.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ type JWT struct {
6666

6767
// WebSocket holds WebSocket-specific configuration
6868
type WebSocket struct {
69-
MaxConnections int `envconfig:"WS_MAX_CONNECTIONS" default:"1000"`
70-
ConnectionTimeout int `envconfig:"WS_CONNECTION_TIMEOUT" default:"30"` // seconds
71-
RateLimitPerMin int `envconfig:"WS_RATE_LIMIT_PER_MINUTE" default:"10"`
69+
MaxConnections int `envconfig:"WS_MAX_CONNECTIONS" default:"1000"`
70+
ConnectionTimeout int `envconfig:"WS_CONNECTION_TIMEOUT" default:"30"` // seconds
71+
RateLimitPerMin int `envconfig:"WS_RATE_LIMIT_PER_MINUTE" default:"10"`
72+
MaxConnectionsPerOrg int `envconfig:"WS_MAX_CONNECTIONS_PER_ORG" default:"3"`
7273
}
7374

7475
// Database holds database-specific configuration

platform-api/src/internal/handler/websocket.go

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,34 @@ func (h *WebSocketHandler) Connect(c *gin.Context) {
105105
transport := ws.NewWebSocketTransport(conn)
106106

107107
// Register connection with manager
108-
connection, err := h.manager.Register(gateway.ID, transport, apiKey)
108+
connection, err := h.manager.Register(gateway.ID, transport, apiKey, gateway.OrganizationID)
109109
if err != nil {
110-
log.Printf("[ERROR] Connection registration failed: gatewayID=%s error=%v", gateway.ID, err)
111-
// Send error message before closing
112-
errorMsg := map[string]string{
113-
"type": "error",
114-
"message": err.Error(),
115-
}
116-
if jsonErr, _ := json.Marshal(errorMsg); jsonErr != nil {
117-
conn.WriteMessage(websocket.TextMessage, jsonErr)
110+
log.Printf("[ERROR] Connection registration failed: gatewayID=%s orgID=%s error=%v",
111+
gateway.ID, gateway.OrganizationID, err)
112+
113+
// Check if this is an org connection limit error
114+
if orgLimitErr, ok := err.(*ws.OrgConnectionLimitError); ok {
115+
errorMsg := map[string]interface{}{
116+
"type": "error",
117+
"code": "ORG_CONNECTION_LIMIT_EXCEEDED",
118+
"message": "Organization connection limit reached",
119+
"currentCount": orgLimitErr.CurrentCount,
120+
"maxAllowed": orgLimitErr.MaxAllowed,
121+
}
122+
if jsonErr, _ := json.Marshal(errorMsg); jsonErr != nil {
123+
conn.WriteMessage(websocket.TextMessage, jsonErr)
124+
}
125+
log.Printf("[WARN] Organization connection limit exceeded: orgID=%s count=%d max=%d",
126+
orgLimitErr.OrganizationID, orgLimitErr.CurrentCount, orgLimitErr.MaxAllowed)
127+
} else {
128+
// Generic error
129+
errorMsg := map[string]string{
130+
"type": "error",
131+
"message": err.Error(),
132+
}
133+
if jsonErr, _ := json.Marshal(errorMsg); jsonErr != nil {
134+
conn.WriteMessage(websocket.TextMessage, jsonErr)
135+
}
118136
}
119137
conn.Close()
120138
return

platform-api/src/internal/repository/organization.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func (r *OrganizationRepo) CreateOrganization(org *model.Organization) error {
4646
VALUES (?, ?, ?, ?, ?, ?)
4747
`
4848
_, err := r.db.Exec(r.db.Rebind(query), org.ID, org.Handle, org.Name, org.Region, org.CreatedAt, org.UpdatedAt)
49+
4950
return err
5051
}
5152

@@ -118,6 +119,7 @@ func (r *OrganizationRepo) UpdateOrganization(org *model.Organization) error {
118119
WHERE uuid = ?
119120
`
120121
_, err := r.db.Exec(r.db.Rebind(query), org.Handle, org.Name, org.Region, org.UpdatedAt, org.ID)
122+
121123
return err
122124
}
123125

platform-api/src/internal/server/server.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,10 @@ func StartPlatformAPIServer(cfg *config.Server) (*Server, error) {
7979

8080
// Initialize WebSocket manager first (needed for GatewayEventsService)
8181
wsConfig := websocket.ManagerConfig{
82-
MaxConnections: cfg.WebSocket.MaxConnections,
83-
HeartbeatInterval: 20 * time.Second,
84-
HeartbeatTimeout: time.Duration(cfg.WebSocket.ConnectionTimeout) * time.Second,
82+
MaxConnections: cfg.WebSocket.MaxConnections,
83+
HeartbeatInterval: 20 * time.Second,
84+
HeartbeatTimeout: time.Duration(cfg.WebSocket.ConnectionTimeout) * time.Second,
85+
MaxConnectionsPerOrg: cfg.WebSocket.MaxConnectionsPerOrg,
8586
}
8687
wsManager := websocket.NewManager(wsConfig)
8788

@@ -145,8 +146,9 @@ func StartPlatformAPIServer(cfg *config.Server) (*Server, error) {
145146
gitHandler.RegisterRoutes(router)
146147
deploymentHandler.RegisterRoutes(router)
147148

148-
log.Printf("[INFO] WebSocket manager initialized: maxConnections=%d heartbeatTimeout=%ds rateLimitPerMin=%d",
149-
cfg.WebSocket.MaxConnections, cfg.WebSocket.ConnectionTimeout, cfg.WebSocket.RateLimitPerMin)
149+
log.Printf("[INFO] WebSocket manager initialized: maxConnections=%d heartbeatTimeout=%ds rateLimitPerMin=%d maxConnectionsPerOrg=%d",
150+
cfg.WebSocket.MaxConnections, cfg.WebSocket.ConnectionTimeout, cfg.WebSocket.RateLimitPerMin,
151+
cfg.WebSocket.MaxConnectionsPerOrg)
150152

151153
return &Server{
152154
router: router,

platform-api/src/internal/websocket/connection.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ type Connection struct {
3636
// Used to distinguish between multiple connections from the same gateway (clustering).
3737
ConnectionID string
3838

39+
// OrganizationID identifies the organization that owns this gateway connection.
40+
// Used for per-organization connection limit tracking and cleanup.
41+
OrganizationID string
42+
3943
// ConnectedAt records when the connection was established
4044
ConnectedAt time.Time
4145

@@ -68,19 +72,21 @@ type Connection struct {
6872
// - connectionID: Unique identifier for this connection instance
6973
// - transport: Transport implementation (e.g., WebSocketTransport)
7074
// - authToken: API key used for authentication
75+
// - orgID: UUID of the organization that owns the gateway
7176
//
7277
// Returns a fully initialized Connection ready for message delivery.
73-
func NewConnection(gatewayID, connectionID string, transport Transport, authToken string) *Connection {
78+
func NewConnection(gatewayID, connectionID string, transport Transport, authToken string, orgID string) *Connection {
7479
now := time.Now()
7580
return &Connection{
76-
GatewayID: gatewayID,
77-
ConnectionID: connectionID,
78-
ConnectedAt: now,
79-
LastHeartbeat: now,
80-
Transport: transport,
81-
AuthToken: authToken,
82-
DeliveryStats: &DeliveryStats{},
83-
closed: false,
81+
GatewayID: gatewayID,
82+
ConnectionID: connectionID,
83+
OrganizationID: orgID,
84+
ConnectedAt: now,
85+
LastHeartbeat: now,
86+
Transport: transport,
87+
AuthToken: authToken,
88+
DeliveryStats: &DeliveryStats{},
89+
closed: false,
8490
}
8591
}
8692

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) 2026, WSO2 LLC. (http://www.wso2.org) All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package websocket
18+
19+
import "fmt"
20+
21+
// OrgConnectionLimitError is returned when an organization has reached its connection limit
22+
type OrgConnectionLimitError struct {
23+
OrganizationID string
24+
CurrentCount int
25+
MaxAllowed int
26+
}
27+
28+
func (e *OrgConnectionLimitError) Error() string {
29+
return fmt.Sprintf("organization %s has reached maximum connection limit: %d/%d",
30+
e.OrganizationID, e.CurrentCount, e.MaxAllowed)
31+
}
32+
33+
// IsOrgConnectionLimitError checks if an error is an OrgConnectionLimitError
34+
func IsOrgConnectionLimitError(err error) bool {
35+
_, ok := err.(*OrgConnectionLimitError)
36+
return ok
37+
}

platform-api/src/internal/websocket/manager.go

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ type Manager struct {
5353
// heartbeatTimeout specifies when to consider a connection dead (default 30s)
5454
heartbeatTimeout time.Duration
5555

56+
// orgLimiter enforces per-organization connection limits
57+
orgLimiter *OrgConnectionLimiter
58+
5659
// shutdownCtx is used to signal graceful shutdown to all connection goroutines
5760
shutdownCtx context.Context
5861
shutdownFn context.CancelFunc
@@ -63,17 +66,19 @@ type Manager struct {
6366

6467
// ManagerConfig contains configuration parameters for the connection manager
6568
type ManagerConfig struct {
66-
MaxConnections int // Maximum concurrent connections (default 1000)
67-
HeartbeatInterval time.Duration // Ping interval (default 20s)
68-
HeartbeatTimeout time.Duration // Pong timeout (default 30s)
69+
MaxConnections int // Maximum concurrent connections (default 1000)
70+
HeartbeatInterval time.Duration // Ping interval (default 20s)
71+
HeartbeatTimeout time.Duration // Pong timeout (default 30s)
72+
MaxConnectionsPerOrg int // Maximum connections per organization (default 3)
6973
}
7074

7175
// DefaultManagerConfig returns sensible default configuration values
7276
func DefaultManagerConfig() ManagerConfig {
7377
return ManagerConfig{
74-
MaxConnections: 1000,
75-
HeartbeatInterval: 20 * time.Second,
76-
HeartbeatTimeout: 30 * time.Second,
78+
MaxConnections: 1000,
79+
HeartbeatInterval: 20 * time.Second,
80+
HeartbeatTimeout: 30 * time.Second,
81+
MaxConnectionsPerOrg: 3,
7782
}
7883
}
7984

@@ -86,6 +91,7 @@ func NewManager(config ManagerConfig) *Manager {
8691
maxConnections: config.MaxConnections,
8792
heartbeatInterval: config.HeartbeatInterval,
8893
heartbeatTimeout: config.HeartbeatTimeout,
94+
orgLimiter: NewOrgConnectionLimiter(config.MaxConnectionsPerOrg),
8995
shutdownCtx: ctx,
9096
shutdownFn: cancel,
9197
}
@@ -98,25 +104,37 @@ func NewManager(config ManagerConfig) *Manager {
98104
// - gatewayID: UUID of the authenticated gateway
99105
// - transport: Transport implementation for message delivery
100106
// - authToken: API key used for authentication
107+
// - orgID: UUID of the organization that owns the gateway
101108
//
102109
// Returns the Connection instance and any error encountered.
103110
//
104111
// Design decision: Support multiple connections per gateway ID by storing
105112
// connections in a slice. This enables gateway clustering where multiple
106113
// instances share the same gateway identity.
107-
func (m *Manager) Register(gatewayID string, transport Transport, authToken string) (*Connection, error) {
108-
// Check connection limit
114+
func (m *Manager) Register(gatewayID string, transport Transport, authToken string,
115+
orgID string) (*Connection, error) {
116+
117+
// Create connection ID early so we can use it for org limiter
118+
connectionID := uuid.New().String()
119+
120+
// Check per-org limit first
121+
if err := m.orgLimiter.AddConnection(orgID, connectionID); err != nil {
122+
return nil, err
123+
}
124+
125+
// Check global connection limit
109126
m.mu.Lock()
110127
if m.connectionCount >= m.maxConnections {
111128
m.mu.Unlock()
129+
// Rollback org limiter
130+
m.orgLimiter.RemoveConnection(orgID, connectionID)
112131
return nil, fmt.Errorf("maximum connection limit reached (%d)", m.maxConnections)
113132
}
114133
m.connectionCount++
115134
m.mu.Unlock()
116135

117-
// Create connection with unique connection ID
118-
connectionID := uuid.New().String()
119-
conn := NewConnection(gatewayID, connectionID, transport, authToken)
136+
// Create connection
137+
conn := NewConnection(gatewayID, connectionID, transport, authToken, orgID)
120138

121139
// Add connection to registry
122140
connsInterface, _ := m.connections.LoadOrStore(gatewayID, []*Connection{})
@@ -128,8 +146,8 @@ func (m *Manager) Register(gatewayID string, transport Transport, authToken stri
128146
m.wg.Add(1)
129147
go m.monitorHeartbeat(conn)
130148

131-
log.Printf("[INFO] Gateway connected: gatewayID=%s connectionID=%s totalConnections=%d",
132-
gatewayID, connectionID, m.GetConnectionCount())
149+
log.Printf("[INFO] Gateway connected: gatewayID=%s connectionID=%s orgID=%s totalConnections=%d orgConnections=%d",
150+
gatewayID, connectionID, orgID, m.GetConnectionCount(), m.orgLimiter.GetOrgConnectionCount(orgID))
133151

134152
return conn, nil
135153
}
@@ -176,13 +194,18 @@ func (m *Manager) Unregister(gatewayID, connectionID string) {
176194
gatewayID, connectionID, err)
177195
}
178196

197+
// Remove from org limiter
198+
if removed.OrganizationID != "" {
199+
m.orgLimiter.RemoveConnection(removed.OrganizationID, connectionID)
200+
}
201+
179202
// Decrement connection count
180203
m.mu.Lock()
181204
m.connectionCount--
182205
m.mu.Unlock()
183206

184-
log.Printf("[INFO] Gateway disconnected: gatewayID=%s connectionID=%s totalConnections=%d",
185-
gatewayID, connectionID, m.GetConnectionCount())
207+
log.Printf("[INFO] Gateway disconnected: gatewayID=%s connectionID=%s orgID=%s totalConnections=%d",
208+
gatewayID, connectionID, removed.OrganizationID, m.GetConnectionCount())
186209
}
187210

188211
// GetConnections retrieves all connections for a specific gateway ID.
@@ -301,3 +324,13 @@ func (m *Manager) Shutdown() {
301324

302325
log.Println("[INFO] WebSocket manager shutdown complete")
303326
}
327+
328+
// GetOrgConnectionStats returns connection statistics for a specific organization
329+
func (m *Manager) GetOrgConnectionStats(orgID string) OrgConnectionStats {
330+
return m.orgLimiter.GetOrgStats(orgID)
331+
}
332+
333+
// GetAllOrgConnectionStats returns connection counts for all organizations
334+
func (m *Manager) GetAllOrgConnectionStats() map[string]int {
335+
return m.orgLimiter.GetAllOrgConnectionCounts()
336+
}

0 commit comments

Comments
 (0)