1 package processcreds
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "io/ioutil"
11 "os"
12 "path/filepath"
13 "runtime"
14 "strings"
15 "testing"
16 "time"
17 )
18
19 func TestProviderBadCommand(t *testing.T) {
20 provider := NewProvider("/bad/process")
21 _, err := provider.Retrieve(context.Background())
22 var pe *ProviderError
23 if ok := errors.As(err, &pe); !ok {
24 t.Fatalf("expect error to be of type %T", pe)
25 }
26 if e, a := "error in credential_process", pe.Error(); !strings.Contains(a, e) {
27 t.Errorf("expected %v, got %v", e, a)
28 }
29 }
30
31 func TestProviderMoreEmptyCommands(t *testing.T) {
32 provider := NewProvider("")
33 _, err := provider.Retrieve(context.Background())
34 var pe *ProviderError
35 if ok := errors.As(err, &pe); !ok {
36 t.Fatalf("expect error to be of type %T", pe)
37 }
38 if e, a := "failed to prepare command", pe.Error(); !strings.Contains(a, e) {
39 t.Errorf("expected %v, got %v", e, a)
40 }
41 }
42
43 func TestProviderExpectErrors(t *testing.T) {
44 provider := NewProvider(
45 fmt.Sprintf(
46 "%s %s",
47 getOSCat(),
48 filepath.Join("testdata", "malformed.json"),
49 ))
50 _, err := provider.Retrieve(context.Background())
51 var pe *ProviderError
52 if ok := errors.As(err, &pe); !ok {
53 t.Fatalf("expect error to be of type %T", pe)
54 }
55 if e, a := "parse failed of process output", pe.Error(); !strings.Contains(a, e) {
56 t.Errorf("expected %v, got %v", e, a)
57 }
58
59 provider = NewProvider(
60 fmt.Sprintf("%s %s",
61 getOSCat(),
62 filepath.Join("testdata", "wrongversion.json"),
63 ))
64 _, err = provider.Retrieve(context.Background())
65 if ok := errors.As(err, &pe); !ok {
66 t.Fatalf("expect error to be of type %T", pe)
67 }
68 if e, a := "wrong version in process output", pe.Error(); !strings.Contains(a, e) {
69 t.Errorf("expected %v, got %v", e, a)
70 }
71
72 provider = NewProvider(
73 fmt.Sprintf(
74 "%s %s",
75 getOSCat(),
76 filepath.Join("testdata", "missingkey.json"),
77 ))
78 _, err = provider.Retrieve(context.Background())
79 if ok := errors.As(err, &pe); !ok {
80 t.Fatalf("expect error to be of type %T", pe)
81 }
82 if e, a := "missing AccessKeyId", pe.Error(); !strings.Contains(a, e) {
83 t.Errorf("expected %v, got %v", e, a)
84 }
85
86 provider = NewProvider(
87 fmt.Sprintf(
88 "%s %s",
89 getOSCat(),
90 filepath.Join("testdata", "missingsecret.json"),
91 ))
92 _, err = provider.Retrieve(context.Background())
93 if ok := errors.As(err, &pe); !ok {
94 t.Fatalf("expect error to be of type %T", pe)
95 }
96 if e, a := "missing SecretAccessKey", pe.Error(); !strings.Contains(a, e) {
97 t.Errorf("expected %v, got %v", e, a)
98 }
99 }
100
101 func TestProviderTimeout(t *testing.T) {
102 command := "/bin/sleep 2"
103 if runtime.GOOS == "windows" {
104
105 command = "ping -n 2 127.0.0.1>nul"
106 }
107
108 provider := NewProvider(command, func(options *Options) {
109 options.Timeout = time.Duration(1) * time.Second
110 })
111 _, err := provider.Retrieve(context.Background())
112 var pe *ProviderError
113 if ok := errors.As(err, &pe); !ok {
114 t.Fatalf("expect error to be of type %T", pe)
115 }
116 if e, a := "credential process timed out", pe.Error(); !strings.Contains(a, e) {
117 t.Errorf("expected %v, got %v", e, a)
118 }
119 }
120
121 func TestProviderWithLongSessionToken(t *testing.T) {
122 provider := NewProvider(
123 fmt.Sprintf(
124 "%s %s",
125 getOSCat(),
126 filepath.Join("testdata", "longsessiontoken.json"),
127 ))
128 v, err := provider.Retrieve(context.Background())
129 if err != nil {
130 t.Errorf("expected %v, got %v", "no error", err)
131 }
132
133
134 e := "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
135 if a := v.SessionToken; e != a {
136 t.Errorf("expected %v, got %v", e, a)
137 }
138 }
139
// credentialTest is the JSON document the tests write to temporary files and
// feed back to the provider as credential_process output.
//
// Every field carries an explicit tag equal to the key encoding/json would
// use anyway, except AccessKeyID, whose key is spelled "AccessKeyId".
type credentialTest struct {
	Version         int    `json:"Version"`
	AccessKeyID     string `json:"AccessKeyId"`
	SecretAccessKey string `json:"SecretAccessKey"`
	Expiration      string `json:"Expiration"`
}
146
147 func TestProviderStatic(t *testing.T) {
148
149 provider := NewProvider(
150 fmt.Sprintf(
151 "%s %s",
152 getOSCat(),
153 filepath.Join("testdata", "static.json"),
154 ))
155 v, err := provider.Retrieve(context.Background())
156 if err != nil {
157 t.Errorf("expected %v, got %v", "no error", err)
158 }
159 if v.CanExpire != false {
160 t.Errorf("expected %v, got %v", "static credentials/not expired", "can expire")
161 }
162
163 }
164
165 func TestProviderNotExpired(t *testing.T) {
166
167 exp := &credentialTest{}
168 exp.Version = 1
169 exp.AccessKeyID = "accesskey"
170 exp.SecretAccessKey = "secretkey"
171 exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
172 b, err := json.Marshal(exp)
173 if err != nil {
174 t.Errorf("expected %v, got %v", "no error", err)
175 }
176
177 tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expiring")
178 if err != nil {
179 t.Errorf("expected %v, got %v", "no error", err)
180 }
181 if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
182 t.Errorf("expected %v, got %v", "no error", err)
183 }
184 defer func() {
185 if err = tmpFile.Close(); err != nil {
186 t.Errorf("expected %v, got %v", "no error", err)
187 }
188 if err = os.Remove(tmpFile.Name()); err != nil {
189 t.Errorf("expected %v, got %v", "no error", err)
190 }
191 }()
192 provider := NewProvider(
193 fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
194 v, err := provider.Retrieve(context.Background())
195 if err != nil {
196 t.Errorf("expected %v, got %v", "no error", err)
197 }
198 if v.Expired() {
199 t.Errorf("expected %v, got %v", "not expired", "expired")
200 }
201 }
202
203 func TestProviderExpired(t *testing.T) {
204
205 exp := &credentialTest{}
206 exp.Version = 1
207 exp.AccessKeyID = "accesskey"
208 exp.SecretAccessKey = "secretkey"
209 exp.Expiration = time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339)
210 b, err := json.Marshal(exp)
211 if err != nil {
212 t.Errorf("expected %v, got %v", "no error", err)
213 }
214
215 tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expired")
216 if err != nil {
217 t.Errorf("expected %v, got %v", "no error", err)
218 }
219 if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
220 t.Errorf("expected %v, got %v", "no error", err)
221 }
222 defer func() {
223 if err = tmpFile.Close(); err != nil {
224 t.Errorf("expected %v, got %v", "no error", err)
225 }
226 if err = os.Remove(tmpFile.Name()); err != nil {
227 t.Errorf("expected %v, got %v", "no error", err)
228 }
229 }()
230 provider := NewProvider(
231 fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
232 v, err := provider.Retrieve(context.Background())
233 if err != nil {
234 t.Errorf("expected %v, got %v", "no error", err)
235 }
236 if !v.Expired() {
237 t.Errorf("expected %v, got %v", "expired", "not expired")
238 }
239 }
240
241 func TestProviderForceExpire(t *testing.T) {
242
243
244
245 exp := &credentialTest{}
246 exp.Version = 1
247 exp.AccessKeyID = "accesskey"
248 exp.SecretAccessKey = "secretkey"
249 exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
250 b, err := json.Marshal(exp)
251 if err != nil {
252 t.Errorf("expected %v, got %v", "no error", err)
253 }
254 tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_force_expire")
255 if err != nil {
256 t.Errorf("expected %v, got %v", "no error", err)
257 }
258 if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
259 t.Errorf("expected %v, got %v", "no error", err)
260 }
261 defer func() {
262 if err = tmpFile.Close(); err != nil {
263 t.Errorf("expected %v, got %v", "no error", err)
264 }
265 if err = os.Remove(tmpFile.Name()); err != nil {
266 t.Errorf("expected %v, got %v", "no error", err)
267 }
268 }()
269
270
271 provider := NewProvider(
272 fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
273 v, err := provider.Retrieve(context.Background())
274 if err != nil {
275 t.Errorf("expected %v, got %v", "no error", err)
276 }
277 if v.Expired() {
278 t.Errorf("expected %v, got %v", "not expired", "expired")
279 }
280
281
282 v, err = provider.Retrieve(context.Background())
283 if err != nil {
284 t.Errorf("expected %v, got %v", "no error", err)
285 }
286 if v.Expired() {
287 t.Errorf("expected %v, got %v", "not expired", "expired")
288 }
289 }
290
291 func TestProviderAltConstruct(t *testing.T) {
292 cmdBuilder := DefaultNewCommandBuilder{Args: []string{
293 fmt.Sprintf("%s %s", getOSCat(),
294 filepath.Join("testdata", "static.json"),
295 ),
296 }}
297
298 provider := NewProviderCommand(cmdBuilder, func(options *Options) {
299 options.Timeout = time.Duration(1) * time.Second
300 })
301 v, err := provider.Retrieve(context.Background())
302 if err != nil {
303 t.Errorf("expected %v, got %v", "no error", err)
304 }
305 if v.CanExpire != false {
306 t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
307 }
308 }
309
310 func BenchmarkProcessProvider(b *testing.B) {
311 provider := NewProvider(
312 fmt.Sprintf(
313 "%s %s",
314 getOSCat(),
315 filepath.Join("testdata", "static.json"),
316 ))
317 _, err := provider.Retrieve(context.Background())
318 if err != nil {
319 b.Fatal(err)
320 }
321
322 b.ResetTimer()
323 for i := 0; i < b.N; i++ {
324 b.StartTimer()
325 _, err := provider.Retrieve(context.Background())
326 if err != nil {
327 b.Fatal(err)
328 }
329 b.StopTimer()
330 }
331 }
332
// getOSCat returns the command used to print a file to stdout on the current
// platform: "type" on Windows, "cat" everywhere else.
func getOSCat() string {
	switch runtime.GOOS {
	case "windows":
		return "type"
	default:
		return "cat"
	}
}
339
View as plain text