1 package ssocreds
2
3 import (
4 "context"
5 "fmt"
6 "path/filepath"
7 "strings"
8 "testing"
9 "time"
10
11 "github.com/aws/aws-sdk-go-v2/aws"
12 "github.com/aws/aws-sdk-go-v2/internal/sdk"
13 "github.com/aws/aws-sdk-go-v2/service/sso"
14 "github.com/aws/aws-sdk-go-v2/service/sso/types"
15 )
16
17 type mockClient struct {
18 t *testing.T
19
20 Output *sso.GetRoleCredentialsOutput
21 Err error
22
23 ExpectedAccountID string
24 ExpectedAccessToken string
25 ExpectedRoleName string
26
27 Response func(mockClient) (*sso.GetRoleCredentialsOutput, error)
28 }
29
30 func (m mockClient) GetRoleCredentials(ctx context.Context, params *sso.GetRoleCredentialsInput, optFns ...func(options *sso.Options)) (out *sso.GetRoleCredentialsOutput, err error) {
31 m.t.Helper()
32
33 if len(m.ExpectedAccountID) > 0 {
34 if diff := cmpDiff(m.ExpectedAccountID, aws.ToString(params.AccountId)); len(diff) > 0 {
35 m.t.Error(diff)
36 }
37 }
38
39 if len(m.ExpectedAccessToken) > 0 {
40 if diff := cmpDiff(m.ExpectedAccessToken, aws.ToString(params.AccessToken)); len(diff) > 0 {
41 m.t.Error(diff)
42 }
43 }
44
45 if len(m.ExpectedRoleName) > 0 {
46 if diff := cmpDiff(m.ExpectedRoleName, aws.ToString(params.RoleName)); len(diff) > 0 {
47 m.t.Error(diff)
48 }
49 }
50
51 if m.Response == nil {
52 return out, err
53 }
54 return m.Response(m)
55 }
56
// TestProvider runs table-driven cases through the provider returned by New
// and its Retrieve method.
//
// osUserHomeDur is stubbed to return "testdata", and the SDK reference time is
// pinned to 2021-01-19 19:50 UTC. Both are restored when the test returns, so
// every case sees fixed inputs. Each case either expects an error containing
// ExpectedErr or expects credentials equal to ExpectedCredentials. The Region
// field is recorded per case but is not passed to New.
func TestProvider(t *testing.T) {
	origHomeDir := osUserHomeDur
	defer func() {
		osUserHomeDur = origHomeDir
	}()

	osUserHomeDur = func() string {
		return "testdata"
	}

	restoreTime := sdk.TestingUseReferenceTime(time.Date(2021, 01, 19, 19, 50, 0, 0, time.UTC))
	defer restoreTime()

	cases := map[string]struct {
		Client    mockClient
		AccountID string
		Region    string
		RoleName  string
		StartURL  string
		Options   []func(*Options)

		ExpectedErr         string
		ExpectedCredentials aws.Credentials
	}{
		"missing required parameter values": {
			StartURL:    "https://invalid-required",
			ExpectedErr: "cached SSO token must contain accessToken and expiresAt fields",
		},
		"valid required parameter values": {
			Client: mockClient{
				ExpectedAccountID:   "012345678901",
				ExpectedRoleName:    "TestRole",
				ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
					return &sso.GetRoleCredentialsOutput{
						RoleCredentials: &types.RoleCredentials{
							AccessKeyId:     aws.String("AccessKey"),
							SecretAccessKey: aws.String("SecretKey"),
							SessionToken:    aws.String("SessionToken"),
							// Milliseconds since epoch: 2021-01-20T21:22:23.123Z.
							Expiration: 1611177743123,
						},
					}, nil
				},
			},
			AccountID: "012345678901",
			Region:    "us-west-2",
			RoleName:  "TestRole",
			StartURL:  "https://valid-required-only",
			ExpectedCredentials: aws.Credentials{
				AccessKeyID:     "AccessKey",
				SecretAccessKey: "SecretKey",
				SessionToken:    "SessionToken",
				CanExpire:       true,
				Expires:         time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
				Source:          ProviderName,
			},
		},
		"custom cached token file": {
			Client: mockClient{
				ExpectedAccountID:   "012345678901",
				ExpectedRoleName:    "TestRole",
				ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
					return &sso.GetRoleCredentialsOutput{
						RoleCredentials: &types.RoleCredentials{
							AccessKeyId:     aws.String("AccessKey"),
							SecretAccessKey: aws.String("SecretKey"),
							SessionToken:    aws.String("SessionToken"),
							Expiration:      1611177743123,
						},
					}, nil
				},
			},
			Options: []func(*Options){
				func(o *Options) {
					o.CachedTokenFilepath = filepath.Join("testdata", "valid_token.json")
				},
			},
			AccountID: "012345678901",
			Region:    "us-west-2",
			RoleName:  "TestRole",
			StartURL:  "ignored value",
			ExpectedCredentials: aws.Credentials{
				AccessKeyID:     "AccessKey",
				SecretAccessKey: "SecretKey",
				SessionToken:    "SessionToken",
				CanExpire:       true,
				Expires:         time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
				Source:          ProviderName,
			},
		},
		"expired access token": {
			StartURL:    "https://expired",
			ExpectedErr: "SSO session has expired or is invalid",
		},
		"api error": {
			Client: mockClient{
				ExpectedAccountID:   "012345678901",
				ExpectedRoleName:    "TestRole",
				ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
					return nil, fmt.Errorf("api error")
				},
			},
			AccountID:   "012345678901",
			Region:      "us-west-2",
			RoleName:    "TestRole",
			StartURL:    "https://valid-required-only",
			ExpectedErr: "api error",
		},
	}

	for name, tt := range cases {
		t.Run(name, func(t *testing.T) {
			// tt is a per-iteration copy, so wiring the subtest's t is safe.
			tt.Client.t = t

			provider := New(tt.Client, tt.AccountID, tt.RoleName, tt.StartURL, tt.Options...)

			credentials, err := provider.Retrieve(context.Background())
			if tt.ExpectedErr != "" {
				if err == nil {
					t.Fatalf("expect %v error, got none", tt.ExpectedErr)
				}
				if e, a := tt.ExpectedErr, err.Error(); !strings.Contains(a, e) {
					t.Fatalf("expect %v error, got %v", e, a)
				}
				return
			}
			if err != nil {
				t.Fatalf("expect no error, got %v", err)
			}

			// diff is arbitrary text, so it must not be used as a format string.
			if diff := cmpDiff(tt.ExpectedCredentials, credentials); len(diff) > 0 {
				t.Error(diff)
			}
		})
	}
}
195
View as plain text