Skip to content

Commit 22089f7

Browse files
committed
Rate limit apply for gateway connections
1 parent f28bac1 commit 22089f7

File tree

14 files changed

+359
-57
lines changed

14 files changed

+359
-57
lines changed

.vscode/launch.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,4 @@
6868
],
6969
},
7070
]
71-
}
71+
}

platform-api/src/config/config.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ type JWT struct {
6666

6767
// WebSocket holds WebSocket-specific configuration
6868
type WebSocket struct {
69-
MaxConnections int `envconfig:"WS_MAX_CONNECTIONS" default:"1000"`
70-
ConnectionTimeout int `envconfig:"WS_CONNECTION_TIMEOUT" default:"30"` // seconds
71-
RateLimitPerMin int `envconfig:"WS_RATE_LIMIT_PER_MINUTE" default:"10"`
69+
MaxConnections int `envconfig:"WS_MAX_CONNECTIONS" default:"1000"`
70+
ConnectionTimeout int `envconfig:"WS_CONNECTION_TIMEOUT" default:"30"` // seconds
71+
RateLimitPerMin int `envconfig:"WS_RATE_LIMIT_PER_MINUTE" default:"10"`
72+
FreeOrgMaxConnections int `envconfig:"WS_FREE_ORG_MAX_CONNECTIONS" default:"3"`
73+
PaidOrgMaxConnections int `envconfig:"WS_PAID_ORG_MAX_CONNECTIONS" default:"10"`
7274
}
7375

7476
// Database holds database-specific configuration

platform-api/src/internal/constants/constants.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,15 @@ const (
106106
AssociationTypeGateway = "gateway"
107107
AssociationTypeDevPortal = "dev_portal"
108108
)
109+
110+
// Organization Tier Constants
111+
const (
112+
OrgTierFree = "free"
113+
OrgTierPaid = "paid"
114+
)
115+
116+
// ValidOrgTiers Valid organization tiers
117+
var ValidOrgTiers = map[string]bool{
118+
OrgTierFree: true,
119+
OrgTierPaid: true,
120+
}

platform-api/src/internal/database/schema.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ CREATE TABLE IF NOT EXISTS organizations (
2121
handle VARCHAR(255) UNIQUE NOT NULL,
2222
name VARCHAR(255) NOT NULL,
2323
region VARCHAR(63) NOT NULL,
24+
tier VARCHAR(20) NOT NULL DEFAULT 'free' CHECK (tier IN ('free', 'paid')),
2425
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
2526
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
2627
);

platform-api/src/internal/dto/organization.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type CreateOrganizationRequest struct {
2727
Handle string `json:"handle" yaml:"handle" binding:"required"`
2828
Name string `json:"name" yaml:"name" binding:"required"`
2929
Region string `json:"region" yaml:"region" binding:"required"`
30+
Tier string `json:"tier,omitempty" yaml:"tier,omitempty"`
3031
}
3132

3233
// Organization represents an organization entity in the API management platform
@@ -35,6 +36,7 @@ type Organization struct {
3536
Handle string `json:"handle" yaml:"handle"`
3637
Name string `json:"name" yaml:"name"`
3738
Region string `json:"region" yaml:"region"`
39+
Tier string `json:"tier" yaml:"tier"`
3840
CreatedAt time.Time `json:"createdAt" yaml:"createdAt"`
3941
UpdatedAt time.Time `json:"updatedAt" yaml:"updatedAt"`
4042
}

platform-api/src/internal/handler/websocket.go

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"sync"
2424
"time"
2525

26+
"platform-api/src/internal/constants"
2627
"platform-api/src/internal/dto"
2728
"platform-api/src/internal/service"
2829
"platform-api/src/internal/utils"
@@ -36,6 +37,7 @@ import (
3637
type WebSocketHandler struct {
3738
manager *ws.Manager
3839
gatewayService *service.GatewayService
40+
orgService *service.OrganizationService
3941
upgrader websocket.Upgrader
4042

4143
// Rate limiting: track connection attempts per IP
@@ -45,10 +47,12 @@ type WebSocketHandler struct {
4547
}
4648

4749
// NewWebSocketHandler creates a new WebSocket handler
48-
func NewWebSocketHandler(manager *ws.Manager, gatewayService *service.GatewayService, rateLimitCount int) *WebSocketHandler {
50+
func NewWebSocketHandler(manager *ws.Manager, gatewayService *service.GatewayService,
51+
orgService *service.OrganizationService, rateLimitCount int) *WebSocketHandler {
4952
return &WebSocketHandler{
5053
manager: manager,
5154
gatewayService: gatewayService,
55+
orgService: orgService,
5256
upgrader: websocket.Upgrader{
5357
CheckOrigin: func(r *http.Request) bool {
5458
// TODO: Implement proper origin checking in production
@@ -93,6 +97,22 @@ func (h *WebSocketHandler) Connect(c *gin.Context) {
9397
return
9498
}
9599

100+
// Fetch organization to get tier
101+
org, err := h.orgService.GetOrganizationByUUID(gateway.OrganizationID)
102+
if err != nil {
103+
log.Printf("[ERROR] Failed to fetch organization: gatewayID=%s orgID=%s error=%v",
104+
gateway.ID, gateway.OrganizationID, err)
105+
c.JSON(http.StatusInternalServerError, utils.NewErrorResponse(500, "Internal Server Error",
106+
"Failed to fetch organization details"))
107+
return
108+
}
109+
110+
// Default to free tier if organization not found or tier not set
111+
orgTier := constants.OrgTierFree
112+
if org != nil && org.Tier != "" {
113+
orgTier = org.Tier
114+
}
115+
96116
// Upgrade HTTP connection to WebSocket
97117
conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
98118
if err != nil {
@@ -105,16 +125,35 @@ func (h *WebSocketHandler) Connect(c *gin.Context) {
105125
transport := ws.NewWebSocketTransport(conn)
106126

107127
// Register connection with manager
108-
connection, err := h.manager.Register(gateway.ID, transport, apiKey)
128+
connection, err := h.manager.Register(gateway.ID, transport, apiKey, gateway.OrganizationID, orgTier)
109129
if err != nil {
110-
log.Printf("[ERROR] Connection registration failed: gatewayID=%s error=%v", gateway.ID, err)
111-
// Send error message before closing
112-
errorMsg := map[string]string{
113-
"type": "error",
114-
"message": err.Error(),
115-
}
116-
if jsonErr, _ := json.Marshal(errorMsg); jsonErr != nil {
117-
conn.WriteMessage(websocket.TextMessage, jsonErr)
130+
log.Printf("[ERROR] Connection registration failed: gatewayID=%s orgID=%s error=%v",
131+
gateway.ID, gateway.OrganizationID, err)
132+
133+
// Check if this is an org connection limit error
134+
if orgLimitErr, ok := err.(*ws.OrgConnectionLimitError); ok {
135+
errorMsg := map[string]interface{}{
136+
"type": "error",
137+
"code": "ORG_CONNECTION_LIMIT_EXCEEDED",
138+
"message": "Organization connection limit reached",
139+
"currentCount": orgLimitErr.CurrentCount,
140+
"maxAllowed": orgLimitErr.MaxAllowed,
141+
"tier": orgLimitErr.Tier,
142+
}
143+
if jsonErr, _ := json.Marshal(errorMsg); jsonErr != nil {
144+
conn.WriteMessage(websocket.TextMessage, jsonErr)
145+
}
146+
log.Printf("[WARN] Organization connection limit exceeded: orgID=%s tier=%s count=%d max=%d",
147+
orgLimitErr.OrganizationID, orgLimitErr.Tier, orgLimitErr.CurrentCount, orgLimitErr.MaxAllowed)
148+
} else {
149+
// Generic error
150+
errorMsg := map[string]string{
151+
"type": "error",
152+
"message": err.Error(),
153+
}
154+
if jsonErr, _ := json.Marshal(errorMsg); jsonErr != nil {
155+
conn.WriteMessage(websocket.TextMessage, jsonErr)
156+
}
118157
}
119158
conn.Close()
120159
return

platform-api/src/internal/model/organization.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type Organization struct {
2727
Handle string `json:"handle" db:"handle"`
2828
Name string `json:"name" db:"name"`
2929
Region string `json:"region" db:"region"`
30+
Tier string `json:"tier" db:"tier"`
3031
CreatedAt time.Time `json:"createdAt" db:"created_at"`
3132
UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
3233
}

platform-api/src/internal/repository/organization.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,23 +42,24 @@ func (r *OrganizationRepo) CreateOrganization(org *model.Organization) error {
4242
org.UpdatedAt = time.Now()
4343

4444
query := `
45-
INSERT INTO organizations (uuid, handle, name, region, created_at, updated_at)
46-
VALUES (?, ?, ?, ?, ?, ?)
45+
INSERT INTO organizations (uuid, handle, name, region, tier, created_at, updated_at)
46+
VALUES (?, ?, ?, ?, ?, ?, ?)
4747
`
48-
_, err := r.db.Exec(r.db.Rebind(query), org.ID, org.Handle, org.Name, org.Region, org.CreatedAt, org.UpdatedAt)
48+
_, err := r.db.Exec(r.db.Rebind(query), org.ID, org.Handle, org.Name, org.Region, org.Tier, org.CreatedAt, org.UpdatedAt)
49+
4950
return err
5051
}
5152

5253
// GetOrganizationByIdOrHandle retrieves an organization by id or handle
5354
func (r *OrganizationRepo) GetOrganizationByIdOrHandle(id, handle string) (*model.Organization, error) {
5455
org := &model.Organization{}
5556
query := `
56-
SELECT uuid, handle, name, region, created_at, updated_at
57+
SELECT uuid, handle, name, region, tier, created_at, updated_at
5758
FROM organizations
5859
WHERE uuid = ? OR handle = ?
5960
`
6061
err := r.db.QueryRow(r.db.Rebind(query), id, handle).Scan(
61-
&org.ID, &org.Handle, &org.Name, &org.Region, &org.CreatedAt, &org.UpdatedAt,
62+
&org.ID, &org.Handle, &org.Name, &org.Region, &org.Tier, &org.CreatedAt, &org.UpdatedAt,
6263
)
6364
if err != nil {
6465
if errors.Is(err, sql.ErrNoRows) {
@@ -73,12 +74,12 @@ func (r *OrganizationRepo) GetOrganizationByIdOrHandle(id, handle string) (*mode
7374
func (r *OrganizationRepo) GetOrganizationByUUID(orgId string) (*model.Organization, error) {
7475
org := &model.Organization{}
7576
query := `
76-
SELECT uuid, handle, name, region, created_at, updated_at
77+
SELECT uuid, handle, name, region, tier, created_at, updated_at
7778
FROM organizations
7879
WHERE uuid = ?
7980
`
8081
err := r.db.QueryRow(r.db.Rebind(query), orgId).Scan(
81-
&org.ID, &org.Handle, &org.Name, &org.Region, &org.CreatedAt, &org.UpdatedAt,
82+
&org.ID, &org.Handle, &org.Name, &org.Region, &org.Tier, &org.CreatedAt, &org.UpdatedAt,
8283
)
8384
if err != nil {
8485
if errors.Is(err, sql.ErrNoRows) {
@@ -93,12 +94,12 @@ func (r *OrganizationRepo) GetOrganizationByUUID(orgId string) (*model.Organizat
9394
func (r *OrganizationRepo) GetOrganizationByHandle(handle string) (*model.Organization, error) {
9495
org := &model.Organization{}
9596
query := `
96-
SELECT uuid, handle, name, region, created_at, updated_at
97+
SELECT uuid, handle, name, region, tier, created_at, updated_at
9798
FROM organizations
9899
WHERE handle = ?
99100
`
100101
err := r.db.QueryRow(r.db.Rebind(query), handle).Scan(
101-
&org.ID, &org.Handle, &org.Name, &org.Region, &org.CreatedAt, &org.UpdatedAt,
102+
&org.ID, &org.Handle, &org.Name, &org.Region, &org.Tier, &org.CreatedAt, &org.UpdatedAt,
102103
)
103104
if err != nil {
104105
if errors.Is(err, sql.ErrNoRows) {
@@ -114,10 +115,11 @@ func (r *OrganizationRepo) UpdateOrganization(org *model.Organization) error {
114115
org.UpdatedAt = time.Now()
115116
query := `
116117
UPDATE organizations
117-
SET handle = ?, name = ?, region = ?, updated_at = ?
118+
SET handle = ?, name = ?, region = ?, tier = ?, updated_at = ?
118119
WHERE uuid = ?
119120
`
120-
_, err := r.db.Exec(r.db.Rebind(query), org.Handle, org.Name, org.Region, org.UpdatedAt, org.ID)
121+
_, err := r.db.Exec(r.db.Rebind(query), org.Handle, org.Name, org.Region, org.Tier, org.UpdatedAt, org.ID)
122+
121123
return err
122124
}
123125

@@ -131,7 +133,7 @@ func (r *OrganizationRepo) DeleteOrganization(orgId string) error {
131133
// ListOrganizations retrieves organizations with pagination
132134
func (r *OrganizationRepo) ListOrganizations(limit, offset int) ([]*model.Organization, error) {
133135
query := `
134-
SELECT uuid, handle, name, region, created_at, updated_at
136+
SELECT uuid, handle, name, region, tier, created_at, updated_at
135137
FROM organizations
136138
ORDER BY created_at DESC
137139
LIMIT ? OFFSET ?
@@ -145,7 +147,7 @@ func (r *OrganizationRepo) ListOrganizations(limit, offset int) ([]*model.Organi
145147
var organizations []*model.Organization
146148
for rows.Next() {
147149
org := &model.Organization{}
148-
err := rows.Scan(&org.ID, &org.Handle, &org.Name, &org.Region, &org.CreatedAt, &org.UpdatedAt)
150+
err := rows.Scan(&org.ID, &org.Handle, &org.Name, &org.Region, &org.Tier, &org.CreatedAt, &org.UpdatedAt)
149151
if err != nil {
150152
return nil, err
151153
}

platform-api/src/internal/server/server.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,11 @@ func StartPlatformAPIServer(cfg *config.Server) (*Server, error) {
7979

8080
// Initialize WebSocket manager first (needed for GatewayEventsService)
8181
wsConfig := websocket.ManagerConfig{
82-
MaxConnections: cfg.WebSocket.MaxConnections,
83-
HeartbeatInterval: 20 * time.Second,
84-
HeartbeatTimeout: time.Duration(cfg.WebSocket.ConnectionTimeout) * time.Second,
82+
MaxConnections: cfg.WebSocket.MaxConnections,
83+
HeartbeatInterval: 20 * time.Second,
84+
HeartbeatTimeout: time.Duration(cfg.WebSocket.ConnectionTimeout) * time.Second,
85+
FreeOrgMaxConnections: cfg.WebSocket.FreeOrgMaxConnections,
86+
PaidOrgMaxConnections: cfg.WebSocket.PaidOrgMaxConnections,
8587
}
8688
wsManager := websocket.NewManager(wsConfig)
8789

@@ -109,7 +111,7 @@ func StartPlatformAPIServer(cfg *config.Server) (*Server, error) {
109111
apiHandler := handler.NewAPIHandler(apiService)
110112
devPortalHandler := handler.NewDevPortalHandler(devPortalService)
111113
gatewayHandler := handler.NewGatewayHandler(gatewayService)
112-
wsHandler := handler.NewWebSocketHandler(wsManager, gatewayService, cfg.WebSocket.RateLimitPerMin)
114+
wsHandler := handler.NewWebSocketHandler(wsManager, gatewayService, orgService, cfg.WebSocket.RateLimitPerMin)
113115
internalGatewayHandler := handler.NewGatewayInternalAPIHandler(gatewayService, internalGatewayService)
114116
gitHandler := handler.NewGitHandler(gitService)
115117
deploymentHandler := handler.NewDeploymentHandler(deploymentService)
@@ -145,8 +147,9 @@ func StartPlatformAPIServer(cfg *config.Server) (*Server, error) {
145147
gitHandler.RegisterRoutes(router)
146148
deploymentHandler.RegisterRoutes(router)
147149

148-
log.Printf("[INFO] WebSocket manager initialized: maxConnections=%d heartbeatTimeout=%ds rateLimitPerMin=%d",
149-
cfg.WebSocket.MaxConnections, cfg.WebSocket.ConnectionTimeout, cfg.WebSocket.RateLimitPerMin)
150+
log.Printf("[INFO] WebSocket manager initialized: maxConnections=%d heartbeatTimeout=%ds rateLimitPerMin=%d freeOrgMaxConns=%d paidOrgMaxConns=%d",
151+
cfg.WebSocket.MaxConnections, cfg.WebSocket.ConnectionTimeout, cfg.WebSocket.RateLimitPerMin,
152+
cfg.WebSocket.FreeOrgMaxConnections, cfg.WebSocket.PaidOrgMaxConnections)
150153

151154
return &Server{
152155
router: router,

platform-api/src/internal/service/organization.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ func NewOrganizationService(orgRepo repository.OrganizationRepository,
4949
}
5050

5151
func (s *OrganizationService) RegisterOrganization(id string, handle string, name string, region string) (*dto.Organization, error) {
52+
return s.RegisterOrganizationWithTier(id, handle, name, region, constants.OrgTierFree)
53+
}
54+
55+
func (s *OrganizationService) RegisterOrganizationWithTier(id string, handle string, name string, region string, tier string) (*dto.Organization, error) {
5256
// Validate handle is URL friendly
5357
if !s.isURLFriendly(handle) {
5458
return nil, constants.ErrInvalidHandle
@@ -70,12 +74,18 @@ func (s *OrganizationService) RegisterOrganization(id string, handle string, nam
7074
name = handle // Default name to handle if not provided
7175
}
7276

77+
// Default tier to free if not specified or invalid
78+
if tier == "" || !constants.ValidOrgTiers[tier] {
79+
tier = constants.OrgTierFree
80+
}
81+
7382
// Create organization in platform-api database first
7483
org := &dto.Organization{
7584
ID: id,
7685
Handle: handle,
7786
Name: name,
7887
Region: region,
88+
Tier: tier,
7989
CreatedAt: time.Now(),
8090
UpdatedAt: time.Now(),
8191
}
@@ -151,6 +161,7 @@ func (s *OrganizationService) dtoToModel(dto *dto.Organization) *model.Organizat
151161
Handle: dto.Handle,
152162
Name: dto.Name,
153163
Region: dto.Region,
164+
Tier: dto.Tier,
154165
CreatedAt: dto.CreatedAt,
155166
UpdatedAt: dto.UpdatedAt,
156167
}
@@ -166,6 +177,7 @@ func (s *OrganizationService) modelToDTO(model *model.Organization) *dto.Organiz
166177
Handle: model.Handle,
167178
Name: model.Name,
168179
Region: model.Region,
180+
Tier: model.Tier,
169181
CreatedAt: model.CreatedAt,
170182
UpdatedAt: model.UpdatedAt,
171183
}

0 commit comments

Comments
 (0)