Skip to content

Commit bcd6cd5

Browse files
dulacpcemalkilic
andauthored
feat(oauth-server): store and enforce token_endpoint_auth_method (#2300)
## Problem I noticed there was a TODO for storing the `token_endpoint_auth_method` value. While integrating with Claude.ai's OAuth flow, we discovered that returning `client_secret_basic` for all clients (regardless of their actual registration) was breaking the authentication flow. Claude.ai strictly validates the auth method returned during client registration, so it was critical for us to return the correct value. Per [RFC 7591 Section 2](https://datatracker.ietf.org/doc/html/rfc7591#section-2): > If unspecified or omitted, the default is "client_secret_basic" For public clients, the default is `none` since they don't have a client secret. ## Solution Added proper storage and enforcement of `token_endpoint_auth_method`: ### Database Changes - Added `token_endpoint_auth_method` TEXT column (NOT NULL) to `oauth_clients` table - Migration sets default values for existing clients based on their `client_type`: - `confidential` → `client_secret_basic` - `public` → `none` ### Behavior - New clients get `token_endpoint_auth_method` persisted during registration - Token endpoint validates that the authentication method used matches the registered method - Returns the correct `token_endpoint_auth_method` in client registration responses --------- Signed-off-by: Pierre Dulac <dulacpier@gmail.com> Signed-off-by: Pierre Dulac <pierre@entropia.io> Co-authored-by: Cemal Kılıç <cemalkilic@users.noreply.github.com>
1 parent 2d3dbc6 commit bcd6cd5

File tree

8 files changed

+231
-41
lines changed

8 files changed

+231
-41
lines changed

internal/api/middleware.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,18 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
136136
func (a *API) requireOAuthClientAuth(w http.ResponseWriter, r *http.Request) (context.Context, error) {
137137
ctx := r.Context()
138138

139-
clientID, clientSecret, err := oauthserver.ExtractClientCredentials(r)
139+
creds, err := oauthserver.ExtractClientCredentials(r)
140140
if err != nil {
141141
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials: %s", err.Error())
142142
}
143143

144144
// If no client credentials provided, continue without client authentication
145-
if clientID == "" {
145+
if creds.ClientID == "" {
146146
return ctx, nil
147147
}
148148

149149
// Parse client_id as UUID
150-
clientUUID, err := uuid.FromString(clientID)
150+
clientUUID, err := uuid.FromString(creds.ClientID)
151151
if err != nil {
152152
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client_id format")
153153
}
@@ -162,8 +162,13 @@ func (a *API) requireOAuthClientAuth(w http.ResponseWriter, r *http.Request) (co
162162
return nil, apierrors.NewInternalServerError("Error validating client credentials").WithInternalError(err)
163163
}
164164

165-
// Validate authentication using centralized logic
166-
if err := oauthserver.ValidateClientAuthentication(client, clientSecret); err != nil {
165+
// Validate that the auth method used matches the client's registered method
166+
if err := oauthserver.ValidateClientAuthMethod(client, creds.AuthMethod); err != nil {
167+
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "%s", err.Error())
168+
}
169+
170+
// Validate authentication using centralized logic (secret verification)
171+
if err := oauthserver.ValidateClientAuthentication(client, creds.ClientSecret); err != nil {
167172
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "%s", err.Error())
168173
}
169174

internal/api/oauthserver/auth.go

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,41 @@ import (
88
"io"
99
"net/http"
1010
"strings"
11+
12+
"github.com/supabase/auth/internal/models"
1113
)
1214

15+
// ClientCredentials represents the extracted client credentials and authentication method used
16+
type ClientCredentials struct {
17+
ClientID string
18+
ClientSecret string
19+
AuthMethod string
20+
}
21+
1322
// ExtractClientCredentials extracts OAuth client credentials from the request
1423
// Supports Basic auth header, form body parameters, and JSON body parameters
15-
func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, err error) {
24+
func ExtractClientCredentials(r *http.Request) (*ClientCredentials, error) {
25+
creds := &ClientCredentials{}
26+
1627
// First, try Basic auth header: Authorization: Basic base64(client_id:client_secret)
1728
authHeader := r.Header.Get("Authorization")
1829
if authHeader != "" && strings.HasPrefix(authHeader, "Basic ") {
1930
encoded := strings.TrimPrefix(authHeader, "Basic ")
2031
decoded, err := base64.StdEncoding.DecodeString(encoded)
2132
if err != nil {
22-
return "", "", errors.New("invalid basic auth encoding")
33+
return nil, errors.New("invalid basic auth encoding")
2334
}
2435

2536
credentials := string(decoded)
2637
parts := strings.SplitN(credentials, ":", 2)
2738
if len(parts) != 2 {
28-
return "", "", errors.New("invalid basic auth format")
39+
return nil, errors.New("invalid basic auth format")
2940
}
3041

31-
return parts[0], parts[1], nil
42+
creds.ClientID = parts[0]
43+
creds.ClientSecret = parts[1]
44+
creds.AuthMethod = models.TokenEndpointAuthMethodClientSecretBasic
45+
return creds, nil
3246
}
3347

3448
// Check Content-Type to determine how to parse body parameters
@@ -37,7 +51,7 @@ func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, e
3751
// Parse JSON body
3852
body, err := io.ReadAll(r.Body)
3953
if err != nil {
40-
return "", "", errors.New("failed to read request body")
54+
return nil, errors.New("failed to read request body")
4155
}
4256
// Restore the body so other handlers can read it
4357
r.Body = io.NopCloser(bytes.NewBuffer(body))
@@ -47,25 +61,32 @@ func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, e
4761
ClientSecret string `json:"client_secret"`
4862
}
4963
if err := json.Unmarshal(body, &jsonData); err != nil {
50-
return "", "", errors.New("failed to parse JSON body")
64+
return nil, errors.New("failed to parse JSON body")
5165
}
5266

53-
clientID = jsonData.ClientID
54-
clientSecret = jsonData.ClientSecret
67+
creds.ClientID = jsonData.ClientID
68+
creds.ClientSecret = jsonData.ClientSecret
5569
} else {
5670
// Fall back to form parameters
5771
if err := r.ParseForm(); err != nil {
58-
return "", "", errors.New("failed to parse form")
72+
return nil, errors.New("failed to parse form")
5973
}
6074

61-
clientID = r.FormValue("client_id")
62-
clientSecret = r.FormValue("client_secret")
75+
creds.ClientID = r.FormValue("client_id")
76+
creds.ClientSecret = r.FormValue("client_secret")
6377
}
6478

6579
// return error if client_id is not provided
66-
if clientID == "" {
67-
return "", "", errors.New("client_id is required")
80+
if creds.ClientID == "" {
81+
return nil, errors.New("client_id is required")
82+
}
83+
84+
// Determine auth method based on presence of client_secret in body
85+
if creds.ClientSecret != "" {
86+
creds.AuthMethod = models.TokenEndpointAuthMethodClientSecretPost
87+
} else {
88+
creds.AuthMethod = models.TokenEndpointAuthMethodNone
6889
}
6990

70-
return clientID, clientSecret, nil
91+
return creds, nil
7192
}

internal/api/oauthserver/client_auth.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,15 @@ func GetAllValidAuthMethods() []string {
108108
models.TokenEndpointAuthMethodClientSecretPost,
109109
}
110110
}
111+
112+
// ValidateClientAuthMethod validates the authentication method used matches the registered method
113+
func ValidateClientAuthMethod(client *models.OAuthServerClient, usedMethod string) error {
114+
registeredMethod := client.GetTokenEndpointAuthMethod()
115+
116+
if usedMethod != registeredMethod {
117+
return fmt.Errorf("invalid authentication method: client is registered for '%s' but '%s' was used",
118+
registeredMethod, usedMethod)
119+
}
120+
121+
return nil
122+
}

internal/api/oauthserver/client_auth_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,100 @@ func TestGetAllValidAuthMethods(t *testing.T) {
395395
}
396396
}
397397

398+
func TestValidateClientAuthMethod(t *testing.T) {
399+
tests := []struct {
400+
name string
401+
client *models.OAuthServerClient
402+
usedMethod string
403+
expectError bool
404+
errorContains string
405+
}{
406+
{
407+
name: "client registered for basic should accept basic",
408+
client: &models.OAuthServerClient{
409+
ID: uuid.Must(uuid.NewV4()),
410+
ClientType: models.OAuthServerClientTypeConfidential,
411+
TokenEndpointAuthMethod: models.TokenEndpointAuthMethodClientSecretBasic,
412+
},
413+
usedMethod: models.TokenEndpointAuthMethodClientSecretBasic,
414+
expectError: false,
415+
},
416+
{
417+
name: "client registered for post should accept post",
418+
client: &models.OAuthServerClient{
419+
ID: uuid.Must(uuid.NewV4()),
420+
ClientType: models.OAuthServerClientTypeConfidential,
421+
TokenEndpointAuthMethod: models.TokenEndpointAuthMethodClientSecretPost,
422+
},
423+
usedMethod: models.TokenEndpointAuthMethodClientSecretPost,
424+
expectError: false,
425+
},
426+
{
427+
name: "client registered for basic should reject post",
428+
client: &models.OAuthServerClient{
429+
ID: uuid.Must(uuid.NewV4()),
430+
ClientType: models.OAuthServerClientTypeConfidential,
431+
TokenEndpointAuthMethod: models.TokenEndpointAuthMethodClientSecretBasic,
432+
},
433+
usedMethod: models.TokenEndpointAuthMethodClientSecretPost,
434+
expectError: true,
435+
errorContains: "invalid authentication method",
436+
},
437+
{
438+
name: "client registered for post should reject basic",
439+
client: &models.OAuthServerClient{
440+
ID: uuid.Must(uuid.NewV4()),
441+
ClientType: models.OAuthServerClientTypeConfidential,
442+
TokenEndpointAuthMethod: models.TokenEndpointAuthMethodClientSecretPost,
443+
},
444+
usedMethod: models.TokenEndpointAuthMethodClientSecretBasic,
445+
expectError: true,
446+
errorContains: "invalid authentication method",
447+
},
448+
{
449+
name: "public client registered for none should accept none",
450+
client: &models.OAuthServerClient{
451+
ID: uuid.Must(uuid.NewV4()),
452+
ClientType: models.OAuthServerClientTypePublic,
453+
TokenEndpointAuthMethod: models.TokenEndpointAuthMethodNone,
454+
},
455+
usedMethod: models.TokenEndpointAuthMethodNone,
456+
expectError: false,
457+
},
458+
{
459+
name: "public client registered for none should reject basic",
460+
client: &models.OAuthServerClient{
461+
ID: uuid.Must(uuid.NewV4()),
462+
ClientType: models.OAuthServerClientTypePublic,
463+
TokenEndpointAuthMethod: models.TokenEndpointAuthMethodNone,
464+
},
465+
usedMethod: models.TokenEndpointAuthMethodClientSecretBasic,
466+
expectError: true,
467+
errorContains: "invalid authentication method",
468+
},
469+
}
470+
471+
for _, tt := range tests {
472+
t.Run(tt.name, func(t *testing.T) {
473+
err := ValidateClientAuthMethod(tt.client, tt.usedMethod)
474+
475+
if tt.expectError {
476+
if err == nil {
477+
t.Errorf("ValidateClientAuthMethod() expected error but got nil")
478+
return
479+
}
480+
if tt.errorContains != "" && !containsString(err.Error(), tt.errorContains) {
481+
t.Errorf("ValidateClientAuthMethod() error = %v, expected to contain %v", err, tt.errorContains)
482+
}
483+
} else {
484+
if err != nil {
485+
t.Errorf("ValidateClientAuthMethod() expected no error but got: %v", err)
486+
}
487+
}
488+
})
489+
}
490+
}
491+
398492
// Helper function to check if a string contains a substring
399493
func containsString(s, substr string) bool {
400494
return len(s) >= len(substr) && (s == substr || (len(s) > len(substr) &&

internal/api/oauthserver/handlers.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,24 +51,13 @@ type OAuthServerClientListResponse struct {
5151

5252
// oauthServerClientToResponse converts a model to response format
5353
func oauthServerClientToResponse(client *models.OAuthServerClient) *OAuthServerClientResponse {
54-
// Set token endpoint auth methods based on client type
55-
var tokenEndpointAuthMethods string
56-
// TODO(cemal) :: Remove this once we have the token endpoint auth method stored in the database
57-
if client.IsPublic() {
58-
// Public clients don't use client authentication
59-
tokenEndpointAuthMethods = models.TokenEndpointAuthMethodNone
60-
} else {
61-
// Confidential clients use client secret authentication
62-
tokenEndpointAuthMethods = models.TokenEndpointAuthMethodClientSecretBasic
63-
}
64-
6554
response := &OAuthServerClientResponse{
6655
ClientID: client.ID.String(),
6756
ClientType: client.ClientType,
6857

6958
// OAuth 2.1 DCR fields
7059
RedirectURIs: client.GetRedirectURIs(),
71-
TokenEndpointAuthMethod: tokenEndpointAuthMethods,
60+
TokenEndpointAuthMethod: client.GetTokenEndpointAuthMethod(),
7261
GrantTypes: client.GetGrantTypes(),
7362
ResponseTypes: []string{"code"}, // Always "code" in OAuth 2.1
7463
ClientName: utilities.StringValue(client.ClientName),

internal/api/oauthserver/service.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,29 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer
263263
// Determine client type using centralized logic
264264
clientType := DetermineClientType(params.ClientType, params.TokenEndpointAuthMethod)
265265

266+
// Determine token_endpoint_auth_method
267+
// If explicitly provided, use it; otherwise set default based on client type
268+
// Per RFC 7591: "If unspecified or omitted, the default is 'client_secret_basic'"
269+
// For public clients, the default is 'none' since they don't have a client secret
270+
tokenEndpointAuthMethod := params.TokenEndpointAuthMethod
271+
if tokenEndpointAuthMethod == "" {
272+
if clientType == models.OAuthServerClientTypePublic {
273+
tokenEndpointAuthMethod = models.TokenEndpointAuthMethodNone
274+
} else {
275+
tokenEndpointAuthMethod = models.TokenEndpointAuthMethodClientSecretBasic
276+
}
277+
}
278+
266279
db := s.db.WithContext(ctx)
267280

268281
client := &models.OAuthServerClient{
269-
ID: uuid.Must(uuid.NewV4()),
270-
RegistrationType: params.RegistrationType,
271-
ClientType: clientType,
272-
ClientName: utilities.StringPtr(params.ClientName),
273-
ClientURI: utilities.StringPtr(params.ClientURI),
274-
LogoURI: utilities.StringPtr(params.LogoURI),
282+
ID: uuid.Must(uuid.NewV4()),
283+
RegistrationType: params.RegistrationType,
284+
ClientType: clientType,
285+
TokenEndpointAuthMethod: tokenEndpointAuthMethod,
286+
ClientName: utilities.StringPtr(params.ClientName),
287+
ClientURI: utilities.StringPtr(params.ClientURI),
288+
LogoURI: utilities.StringPtr(params.LogoURI),
275289
}
276290

277291
client.SetRedirectURIs(params.RedirectURIs)

internal/models/oauth_client.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"database/sql"
55
"fmt"
66
"net/url"
7+
"slices"
78
"strings"
89
"time"
910

@@ -28,10 +29,11 @@ const (
2829

2930
// OAuthServerClient represents an OAuth client application registered with this OAuth server
3031
type OAuthServerClient struct {
31-
ID uuid.UUID `json:"client_id" db:"id"`
32-
ClientSecretHash string `json:"-" db:"client_secret_hash"`
33-
RegistrationType string `json:"registration_type" db:"registration_type"`
34-
ClientType string `json:"client_type" db:"client_type"`
32+
ID uuid.UUID `json:"client_id" db:"id"`
33+
ClientSecretHash string `json:"-" db:"client_secret_hash"`
34+
RegistrationType string `json:"registration_type" db:"registration_type"`
35+
ClientType string `json:"client_type" db:"client_type"`
36+
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method" db:"token_endpoint_auth_method"`
3537

3638
RedirectURIs string `json:"-" db:"redirect_uris"`
3739
GrantTypes string `json:"grant_types" db:"grant_types"`
@@ -82,6 +84,34 @@ func (c *OAuthServerClient) Validate() error {
8284
return fmt.Errorf("client_secret is not allowed for public clients, use PKCE instead")
8385
}
8486

87+
// Apply default token_endpoint_auth_method per RFC 7591:
88+
// "If unspecified or omitted, the default is 'client_secret_basic'"
89+
// For public clients, the default is 'none' since they don't have a client secret
90+
if c.TokenEndpointAuthMethod == "" {
91+
if c.ClientType == OAuthServerClientTypePublic {
92+
c.TokenEndpointAuthMethod = TokenEndpointAuthMethodNone
93+
} else {
94+
c.TokenEndpointAuthMethod = TokenEndpointAuthMethodClientSecretBasic
95+
}
96+
}
97+
98+
// Validate token_endpoint_auth_method
99+
validMethods := []string{TokenEndpointAuthMethodNone, TokenEndpointAuthMethodClientSecretBasic, TokenEndpointAuthMethodClientSecretPost}
100+
if !slices.Contains(validMethods, c.TokenEndpointAuthMethod) {
101+
return fmt.Errorf("token_endpoint_auth_method must be one of: %s, %s, %s",
102+
TokenEndpointAuthMethodNone, TokenEndpointAuthMethodClientSecretBasic, TokenEndpointAuthMethodClientSecretPost)
103+
}
104+
105+
// Public clients must use 'none'
106+
if c.ClientType == OAuthServerClientTypePublic && c.TokenEndpointAuthMethod != TokenEndpointAuthMethodNone {
107+
return fmt.Errorf("public clients must use token_endpoint_auth_method '%s'", TokenEndpointAuthMethodNone)
108+
}
109+
110+
// Confidential clients cannot use 'none'
111+
if c.ClientType == OAuthServerClientTypeConfidential && c.TokenEndpointAuthMethod == TokenEndpointAuthMethodNone {
112+
return fmt.Errorf("confidential clients cannot use token_endpoint_auth_method '%s'", TokenEndpointAuthMethodNone)
113+
}
114+
85115
return nil
86116
}
87117

@@ -121,6 +151,11 @@ func (c *OAuthServerClient) IsConfidential() bool {
121151
return c.ClientType == OAuthServerClientTypeConfidential
122152
}
123153

154+
// GetTokenEndpointAuthMethod returns the token endpoint auth method
155+
func (c *OAuthServerClient) GetTokenEndpointAuthMethod() string {
156+
return c.TokenEndpointAuthMethod
157+
}
158+
124159
// IsGrantTypeAllowed returns true if the client is allowed to use the specified grant type
125160
func (c *OAuthServerClient) IsGrantTypeAllowed(grantType string) bool {
126161
allowedTypes := c.GetGrantTypes()

0 commit comments

Comments
 (0)