@@ -4,10 +4,13 @@ import (
44 "crypto/tls"
55 "errors"
66 "fmt"
7+ "io"
78 "net/http"
89 "net/http/httptest"
910 "net/url"
11+ "strings"
1012 "testing"
13+ "testing/iotest"
1114
1215 "github.com/stretchr/testify/assert"
1316 "github.com/versent/saml2aws/v2/pkg/cfg"
@@ -63,6 +66,65 @@ func TestGetStateTokenFromOktaPageBody(t *testing.T) {
6366 }
6467}
6568
69+ func TestExtractSessionToken (t * testing.T ) {
70+ tests := []struct {
71+ name string
72+ r io.Reader
73+ expectedToken string
74+ expectedError string
75+ }{
76+ {
77+ name : "response with session token" ,
78+ r : strings .NewReader (`{"sessionToken": "xxxx"}` ),
79+ expectedToken : "xxxx" ,
80+ },
81+ {
82+ name : "response with no session token but with status" ,
83+ r : strings .NewReader (`{"status": "invalid password"}` ),
84+ expectedError : "response does not contain session token, received status is: \" invalid password\" " ,
85+ },
86+ {
87+ name : "response with no session token and no status" ,
88+ r : strings .NewReader (`{}` ),
89+ expectedError : "response does not contain session token" ,
90+ },
91+ {
92+ name : "response is not even json" ,
93+ r : strings .NewReader (`const x = {}` ),
94+ expectedError : "response does not contain session token" ,
95+ },
96+ {
97+ name : "reader returns an error" ,
98+ r : iotest .ErrReader (fmt .Errorf ("failed to read" )),
99+ expectedError : "error retrieving body from response: failed to read" ,
100+ },
101+ }
102+
103+ for _ , tc := range tests {
104+ t .Run (tc .name , func (t * testing.T ) {
105+ resp , err := extractSessionToken (tc .r )
106+ if tc .expectedError != "" {
107+ if err == nil {
108+ t .Fatalf ("Expected error, but got null" )
109+ }
110+ if err .Error () != tc .expectedError {
111+ t .Fatalf ("Expected error %q, but got %q" ,
112+ err .Error (), tc .expectedError ,
113+ )
114+ }
115+ }
116+ if tc .expectedToken != "" {
117+ if err != nil {
118+ t .Fatalf ("Expected token %q, but got error %v" , tc .expectedToken , err )
119+ }
120+ if resp != tc .expectedToken {
121+ t .Fatalf ("Expected token %q, but got %q" , tc .expectedToken , resp )
122+ }
123+ }
124+ })
125+ }
126+ }
127+
66128func TestGetMfaChallengeContext (t * testing.T ) {
67129 ts := httptest .NewTLSServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
68130 _ , _ = w .Write ([]byte ("OK" ))
0 commit comments