1
2
3
4 package ssocreds
5
6 import (
7 "context"
8 "fmt"
9 "io/ioutil"
10 "os"
11 "path/filepath"
12 "reflect"
13 "strings"
14 "testing"
15 "time"
16
17 "github.com/aws/aws-sdk-go-v2/aws"
18 "github.com/aws/aws-sdk-go-v2/internal/sdk"
19 "github.com/aws/aws-sdk-go-v2/service/ssooidc"
20 smithybearer "github.com/aws/smithy-go/auth/bearer"
21 )
22
23 func TestSSOTokenProvider(t *testing.T) {
24 restoreTime := sdk.TestingUseReferenceTime(time.Date(2021, 12, 21, 12, 21, 1, 0, time.UTC))
25 defer restoreTime()
26
27 tempDir, err := ioutil.TempDir(os.TempDir(), "aws-sdk-go-v2-"+t.Name())
28 if err != nil {
29 t.Fatalf("failed to create temporary test directory, %v", err)
30 }
31 defer func() {
32 if err := os.RemoveAll(tempDir); err != nil {
33 t.Errorf("failed to cleanup temporary test directory, %v", err)
34 }
35 }()
36
37 cases := map[string]struct {
38 setup func() error
39 postRetrieve func() error
40 client CreateTokenAPIClient
41 cacheFilePath string
42 optFns []func(*SSOTokenProviderOptions)
43
44 expectToken smithybearer.Token
45 expectErr string
46 }{
47 "no cache file": {
48 cacheFilePath: filepath.Join("testdata", "file_not_exists"),
49 expectErr: "failed to read cached SSO token file",
50 },
51 "invalid json cache file": {
52 cacheFilePath: filepath.Join("testdata", "invalid_json.json"),
53 expectErr: "failed to parse cached SSO token file",
54 },
55 "missing accessToken": {
56 cacheFilePath: filepath.Join("testdata", "missing_accessToken.json"),
57 expectErr: "must contain accessToken and expiresAt fields",
58 },
59 "missing expiresAt": {
60 cacheFilePath: filepath.Join("testdata", "missing_expiresAt.json"),
61 expectErr: "must contain accessToken and expiresAt fields",
62 },
63 "expired no clientSecret": {
64 cacheFilePath: filepath.Join("testdata", "missing_clientSecret.json"),
65 expectErr: "cached SSO token is expired, or not present",
66 },
67 "expired no clientId": {
68 cacheFilePath: filepath.Join("testdata", "missing_clientId.json"),
69 expectErr: "cached SSO token is expired, or not present",
70 },
71 "expired no refreshToken": {
72 cacheFilePath: filepath.Join("testdata", "missing_refreshToken.json"),
73 expectErr: "cached SSO token is expired, or not present",
74 },
75 "valid sso token": {
76 cacheFilePath: filepath.Join("testdata", "valid_token.json"),
77 expectToken: smithybearer.Token{
78 Value: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
79 CanExpire: true,
80 Expires: time.Date(2044, 4, 4, 7, 0, 1, 0, time.UTC),
81 },
82 },
83 "refresh expired token": {
84 setup: func() error {
85 testFile, err := os.ReadFile(filepath.Join("testdata", "expired_token.json"))
86 if err != nil {
87 return err
88 }
89
90 return os.WriteFile(filepath.Join(tempDir, "expired_token.json"), testFile, 0600)
91 },
92 postRetrieve: func() error {
93 actual, err := loadCachedToken(filepath.Join(tempDir, "expired_token.json"))
94 if err != nil {
95 return err
96
97 }
98 expect := token{
99 tokenKnownFields: tokenKnownFields{
100 AccessToken: "updated access token",
101 ExpiresAt: (*rfc3339)(aws.Time(time.Date(2021, 12, 21, 12, 31, 1, 0, time.UTC))),
102
103 RefreshToken: "updated refresh token",
104 ClientID: "client id",
105 ClientSecret: "client secret",
106 },
107 UnknownFields: map[string]interface{}{
108 "unknownField": "some value",
109 },
110 }
111
112 if diff := cmpDiff(expect, actual); diff != "" {
113 return fmt.Errorf("expect token file match\n%s", diff)
114 }
115 return nil
116 },
117 cacheFilePath: filepath.Join(tempDir, "expired_token.json"),
118 client: &mockCreateTokenAPIClient{
119 expectInput: &ssooidc.CreateTokenInput{
120 ClientId: aws.String("client id"),
121 ClientSecret: aws.String("client secret"),
122 RefreshToken: aws.String("refresh token"),
123 GrantType: aws.String("refresh_token"),
124 },
125 output: &ssooidc.CreateTokenOutput{
126 AccessToken: aws.String("updated access token"),
127 ExpiresIn: 600,
128 RefreshToken: aws.String("updated refresh token"),
129 },
130 },
131 expectToken: smithybearer.Token{
132 Value: "updated access token",
133 CanExpire: true,
134 Expires: time.Date(2021, 12, 21, 12, 31, 1, 0, time.UTC),
135 },
136 },
137 "fail refresh expired token": {
138 setup: func() error {
139 testFile, err := os.ReadFile(filepath.Join("testdata", "expired_token.json"))
140 if err != nil {
141 return err
142 }
143 return os.WriteFile(filepath.Join(tempDir, "expired_token.json"), testFile, 0600)
144 },
145 postRetrieve: func() error {
146 actual, err := loadCachedToken(filepath.Join(tempDir, "expired_token.json"))
147 if err != nil {
148 return err
149
150 }
151 expect := token{
152 tokenKnownFields: tokenKnownFields{
153 AccessToken: "access token",
154 ExpiresAt: (*rfc3339)(aws.Time(time.Date(2021, 12, 21, 12, 21, 1, 0, time.UTC))),
155
156 RefreshToken: "refresh token",
157 ClientID: "client id",
158 ClientSecret: "client secret",
159 },
160 }
161
162 if diff := cmpDiff(expect, actual); diff != "" {
163 return fmt.Errorf("expect token file match\n%s", diff)
164 }
165 return nil
166 },
167 cacheFilePath: filepath.Join(tempDir, "expired_token.json"),
168 client: &mockCreateTokenAPIClient{
169 err: fmt.Errorf("sky is falling"),
170 },
171 expectErr: "unable to refresh SSO token, sky is falling",
172 },
173 }
174
175 for name, c := range cases {
176 t.Run(name, func(t *testing.T) {
177 if c.setup != nil {
178 if err := c.setup(); err != nil {
179 t.Fatalf("failed to setup test, %v", err)
180 }
181 }
182 provider := NewSSOTokenProvider(c.client, c.cacheFilePath, c.optFns...)
183
184 token, err := provider.RetrieveBearerToken(context.Background())
185 if c.expectErr != "" {
186 if err == nil {
187 t.Fatalf("expect %v error, got none", c.expectErr)
188 }
189 if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) {
190 t.Fatalf("expect %v error, got %v", e, a)
191 }
192 return
193 }
194 if err != nil {
195 t.Fatalf("expect no error, got %v", err)
196 }
197
198 if diff := cmpDiff(c.expectToken, token); diff != "" {
199 t.Errorf("expect token match\n%s", diff)
200 }
201
202 if c.postRetrieve != nil {
203 if err := c.postRetrieve(); err != nil {
204 t.Fatalf("post retrieve failed, %v", err)
205 }
206 }
207 })
208 }
209 }
210
211 type mockCreateTokenAPIClient struct {
212 expectInput *ssooidc.CreateTokenInput
213 output *ssooidc.CreateTokenOutput
214 err error
215 }
216
217 func (c *mockCreateTokenAPIClient) CreateToken(
218 ctx context.Context, input *ssooidc.CreateTokenInput, optFns ...func(*ssooidc.Options)) (
219 *ssooidc.CreateTokenOutput, error,
220 ) {
221 if c.expectInput != nil {
222 if diff := cmpDiff(c.expectInput, input); diff != "" {
223 return nil, fmt.Errorf("expect input match\n%s", diff)
224 }
225 }
226
227 return c.output, c.err
228 }
229
// cmpDiff compares e (expected) and a (actual) with reflect.DeepEqual.
// It returns "" when they are deeply equal. Otherwise it returns a two-line
// description labeling each side and printing struct field names (%+v), so
// failures involving multi-field structs show which value is which.
func cmpDiff(e, a interface{}) string {
	if reflect.DeepEqual(e, a) {
		return ""
	}
	return fmt.Sprintf("expect: %+v\nactual: %+v", e, a)
}
236
View as plain text