Skip to content

Commit a0995ca

Browse files
authored
Merge pull request #823 from krzysztofdrys/fail-early-on-no-token
Return more informative error message incase there is no sessionToken
2 parents d02f727 + 225f0a9 commit a0995ca

File tree

2 files changed

+82
-8
lines changed

2 files changed

+82
-8
lines changed

pkg/provider/okta/okta.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -823,14 +823,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails,
823823
return "", errors.Wrap(err, "error retrieving token post response")
824824
}
825825

826-
body, err := io.ReadAll(res.Body)
827-
if err != nil {
828-
return "", errors.Wrap(err, "error retrieving body from response")
829-
}
830-
831-
resp = string(body)
832-
833-
return gjson.Get(resp, "sessionToken").String(), nil
826+
return extractSessionToken(res.Body)
834827

835828
case IdentifierPushMfa:
836829

@@ -1151,6 +1144,25 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails,
11511144
return "", errors.New("no mfa options provided")
11521145
}
11531146

1147+
func extractSessionToken(r io.Reader) (string, error) {
1148+
bb, err := io.ReadAll(r)
1149+
if err != nil {
1150+
return "", errors.Wrap(err, "error retrieving body from response")
1151+
}
1152+
1153+
resp := string(bb)
1154+
sessionToken := gjson.Get(resp, "sessionToken").String()
1155+
if sessionToken == "" {
1156+
status := gjson.Get(resp, "status").String()
1157+
if status != "" {
1158+
return "", errors.Errorf("response does not contain session token, received status is: %q", status)
1159+
}
1160+
return "", errors.Errorf("response does not contain session token")
1161+
}
1162+
1163+
return gjson.Get(resp, "sessionToken").String(), nil
1164+
}
1165+
11541166
func fidoWebAuthn(oc *Client, oktaOrgHost string, challengeContext *mfaChallengeContext, mfaOption int, stateToken string, mfaOptions []string, resp string) (string, error) {
11551167

11561168
var signedAssertion *SignedAssertion

pkg/provider/okta/okta_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ import (
44
"crypto/tls"
55
"errors"
66
"fmt"
7+
"io"
78
"net/http"
89
"net/http/httptest"
910
"net/url"
11+
"strings"
1012
"testing"
13+
"testing/iotest"
1114

1215
"github.com/stretchr/testify/assert"
1316
"github.com/versent/saml2aws/v2/pkg/cfg"
@@ -63,6 +66,65 @@ func TestGetStateTokenFromOktaPageBody(t *testing.T) {
6366
}
6467
}
6568

69+
func TestExtractSessionToken(t *testing.T) {
70+
tests := []struct {
71+
name string
72+
r io.Reader
73+
expectedToken string
74+
expectedError string
75+
}{
76+
{
77+
name: "response with session token",
78+
r: strings.NewReader(`{"sessionToken": "xxxx"}`),
79+
expectedToken: "xxxx",
80+
},
81+
{
82+
name: "response with no session token but with status",
83+
r: strings.NewReader(`{"status": "invalid password"}`),
84+
expectedError: "response does not contain session token, received status is: \"invalid password\"",
85+
},
86+
{
87+
name: "response with no session token and no status",
88+
r: strings.NewReader(`{}`),
89+
expectedError: "response does not contain session token",
90+
},
91+
{
92+
name: "response is not even json",
93+
r: strings.NewReader(`const x = {}`),
94+
expectedError: "response does not contain session token",
95+
},
96+
{
97+
name: "reader returns an error",
98+
r: iotest.ErrReader(fmt.Errorf("failed to read")),
99+
expectedError: "error retrieving body from response: failed to read",
100+
},
101+
}
102+
103+
for _, tc := range tests {
104+
t.Run(tc.name, func(t *testing.T) {
105+
resp, err := extractSessionToken(tc.r)
106+
if tc.expectedError != "" {
107+
if err == nil {
108+
t.Fatalf("Expected error, but got null")
109+
}
110+
if err.Error() != tc.expectedError {
111+
t.Fatalf("Expected error %q, but got %q",
112+
err.Error(), tc.expectedError,
113+
)
114+
}
115+
}
116+
if tc.expectedToken != "" {
117+
if err != nil {
118+
t.Fatalf("Expected token %q, but got error %v", tc.expectedToken, err)
119+
}
120+
if resp != tc.expectedToken {
121+
t.Fatalf("Expected token %q, but got %q", tc.expectedToken, resp)
122+
}
123+
}
124+
})
125+
}
126+
}
127+
66128
func TestGetMfaChallengeContext(t *testing.T) {
67129
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
68130
_, _ = w.Write([]byte("OK"))

0 commit comments

Comments
 (0)