1 package config
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "io/ioutil"
8 "net/http"
9 "net/http/httptest"
10 "os"
11 "strconv"
12 "testing"
13
14 "github.com/aws/aws-sdk-go-v2/aws"
15 awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
16 "github.com/aws/aws-sdk-go-v2/credentials"
17 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
18 "github.com/aws/aws-sdk-go-v2/internal/awstesting"
19 "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
20 "github.com/aws/smithy-go/logging"
21 )
22
23 func TestResolveCustomCABundle(t *testing.T) {
24 var options LoadOptions
25 var cfg aws.Config
26 cfg.HTTPClient = awshttp.NewBuildableClient()
27
28 WithCustomCABundle(bytes.NewReader(awstesting.TLSBundleCA))(&options)
29 configs := configs{options}
30
31 if err := resolveCustomCABundle(context.Background(), &cfg, configs); err != nil {
32 t.Fatalf("expect no error, got %v", err)
33 }
34
35 type transportGetter interface {
36 GetTransport() *http.Transport
37 }
38
39 trGetter := cfg.HTTPClient.(transportGetter)
40 tr := trGetter.GetTransport()
41 if tr.TLSClientConfig.RootCAs == nil {
42 t.Errorf("expect root CAs set")
43 }
44 }
45
46 func TestResolveCustomCABundle_ValidCA(t *testing.T) {
47 certFile, keyFile, caFile, err := awstesting.CreateTLSBundleFiles()
48 if err != nil {
49 t.Fatalf("failed to create cert temp files, %v", err)
50 }
51 defer func() {
52 awstesting.CleanupTLSBundleFiles(certFile, keyFile, caFile)
53 }()
54
55 serverAddr, err := awstesting.CreateTLSServer(certFile, keyFile, nil)
56 if err != nil {
57 t.Fatalf("failed to start TLS server, %v", err)
58 }
59
60 caPEM, err := ioutil.ReadFile(caFile)
61 if err != nil {
62 t.Fatalf("failed to read CA file, %v", err)
63 }
64
65 var options LoadOptions
66 var cfg aws.Config
67 cfg.HTTPClient = awshttp.NewBuildableClient()
68
69 WithCustomCABundle(bytes.NewReader(caPEM))(&options)
70 configs := configs{options}
71
72 if err := resolveCustomCABundle(context.Background(), &cfg, configs); err != nil {
73 t.Fatalf("expect no error, got %v", err)
74 }
75
76 req, _ := http.NewRequest("GET", serverAddr, nil)
77 resp, err := cfg.HTTPClient.Do(req)
78 if err != nil {
79 t.Fatalf("failed to make request to TLS server, %v", err)
80 }
81 resp.Body.Close()
82
83 if e, a := http.StatusOK, resp.StatusCode; e != a {
84 t.Errorf("expect %v status, got %v", e, a)
85 }
86 }
87
88 func TestResolveCustomCABundle_ErrorCustomClient(t *testing.T) {
89 var options LoadOptions
90 var cfg aws.Config
91
92 cfg.HTTPClient = &http.Client{}
93
94 WithCustomCABundle(bytes.NewReader(awstesting.TLSBundleCA))(&options)
95 configs := configs{options}
96
97 if err := resolveCustomCABundle(context.Background(), &cfg, configs); err == nil {
98 t.Fatalf("expect error, got none")
99 }
100 }
101
102 func TestResolveRegion(t *testing.T) {
103 var options LoadOptions
104 optFns := []func(options *LoadOptions) error{
105 WithRegion("ignored-region"),
106
107 WithRegion("mock-region"),
108 }
109
110 for _, optFn := range optFns {
111 optFn(&options)
112 }
113
114 configs := configs{options}
115
116 var cfg aws.Config
117
118 if err := resolveRegion(context.Background(), &cfg, configs); err != nil {
119 t.Fatalf("expect no error, got %v", err)
120 }
121
122 if e, a := "mock-region", cfg.Region; e != a {
123 t.Errorf("expect %v region, got %v", e, a)
124 }
125 }
126
127 func TestResolveAppID(t *testing.T) {
128 var options LoadOptions
129 optFns := []func(options *LoadOptions) error{
130 WithAppID("1234"),
131
132 WithAppID("5678"),
133 }
134
135 for _, optFn := range optFns {
136 optFn(&options)
137 }
138
139 configs := configs{options}
140
141 var cfg aws.Config
142
143 if err := resolveAppID(context.Background(), &cfg, configs); err != nil {
144 t.Fatalf("expect no error, got %v", err)
145 }
146
147 if e, a := "5678", cfg.AppID; e != a {
148 t.Errorf("expect %v app ID, got %v", e, a)
149 }
150 }
151
152 func TestResolveRequestMinCompressSizeBytes(t *testing.T) {
153 cases := map[string]struct {
154 RequestMinCompressSizeBytes *int64
155 ExpectMinBytes int64
156 }{
157 "min requet size of 100 bytes": {
158 RequestMinCompressSizeBytes: aws.Int64(100),
159 ExpectMinBytes: 100,
160 },
161 "min request size unset": {
162 ExpectMinBytes: 10240,
163 },
164 }
165
166 for name, c := range cases {
167 t.Run(name, func(t *testing.T) {
168 var options LoadOptions
169 optFns := []func(options *LoadOptions) error{
170 WithRequestMinCompressSizeBytes(c.RequestMinCompressSizeBytes),
171 }
172
173 for _, optFn := range optFns {
174 optFn(&options)
175 }
176
177 configs := configs{options}
178
179 var cfg aws.Config
180
181 if err := resolveRequestMinCompressSizeBytes(context.Background(), &cfg, configs); err != nil {
182 t.Fatalf("expect no error, got %v", err)
183 }
184
185 if e, a := c.ExpectMinBytes, cfg.RequestMinCompressSizeBytes; e != a {
186 t.Errorf("expect RequestMinCompressSizeBytes to be %v , got %v", e, a)
187 }
188 })
189 }
190 }
191
192 func TestResolveDisableRequestCompression(t *testing.T) {
193 cases := map[string]struct {
194 DisableRequestCompression *bool
195 ExpectDisable bool
196 }{
197 "disable request compression": {
198 DisableRequestCompression: aws.Bool(true),
199 ExpectDisable: true,
200 },
201 "disable request compression unset": {
202 ExpectDisable: false,
203 },
204 }
205
206 for name, c := range cases {
207 t.Run(name, func(t *testing.T) {
208 var options LoadOptions
209 optFns := []func(options *LoadOptions) error{
210 WithDisableRequestCompression(c.DisableRequestCompression),
211 }
212
213 for _, optFn := range optFns {
214 optFn(&options)
215 }
216
217 configs := configs{options}
218
219 var cfg aws.Config
220
221 if err := resolveDisableRequestCompression(context.Background(), &cfg, configs); err != nil {
222 t.Fatalf("expect no error, got %v", err)
223 }
224
225 if e, a := c.ExpectDisable, cfg.DisableRequestCompression; e != a {
226 t.Errorf("expect DisableRequestCompression to be %v , got %v", e, a)
227 }
228 })
229 }
230 }
231
232 func TestResolveCredentialsProvider(t *testing.T) {
233 var options LoadOptions
234 optFns := []func(options *LoadOptions) error{
235 WithCredentialsProvider(credentials.StaticCredentialsProvider{
236 Value: aws.Credentials{
237 AccessKeyID: "AKID",
238 SecretAccessKey: "SECRET",
239 Source: "valid",
240 }},
241 ),
242 }
243
244 for _, optFn := range optFns {
245 optFn(&options)
246 }
247
248 configs := configs{options}
249
250 var cfg aws.Config
251 cfg.Credentials = nil
252
253 if found, err := resolveCredentialProvider(context.Background(), &cfg, configs); err != nil {
254 t.Fatalf("expect no error, got %v", err)
255 } else if e, a := true, found; e != a {
256 t.Fatalf("expected %v, got %v", e, a)
257 }
258
259 _, ok := cfg.Credentials.(*aws.CredentialsCache)
260 if !ok {
261 t.Fatalf("expect resolved credentials to be wrapped in cache, was not, %T", cfg.Credentials)
262 }
263
264 creds, err := cfg.Credentials.Retrieve(context.Background())
265 if err != nil {
266 t.Fatalf("expect no error, got %v", err)
267 }
268
269 if e, a := "AKID", creds.AccessKeyID; e != a {
270 t.Errorf("expect %v key, got %v", e, a)
271 }
272 if e, a := "SECRET", creds.SecretAccessKey; e != a {
273 t.Errorf("expect %v secret, got %v", e, a)
274 }
275 if e, a := "valid", creds.Source; e != a {
276 t.Errorf("expect %v provider name, got %v", e, a)
277 }
278 }
279
280 func TestDefaultRegion(t *testing.T) {
281 ctx := context.Background()
282
283 var options LoadOptions
284 WithDefaultRegion("foo-region")(&options)
285
286 configs := configs{options}
287 cfg := unit.Config()
288
289 err := resolveDefaultRegion(ctx, &cfg, configs)
290 if err != nil {
291 t.Fatalf("expected no error, got %v", err)
292 }
293
294 if e, a := "mock-region", cfg.Region; e != a {
295 t.Errorf("expected %v, got %v", e, a)
296 }
297
298 cfg.Region = ""
299
300 err = resolveDefaultRegion(ctx, &cfg, configs)
301 if err != nil {
302 t.Fatalf("expected no error, got %v", err)
303 }
304
305 if e, a := "foo-region", cfg.Region; e != a {
306 t.Errorf("expected %v, got %v", e, a)
307 }
308 }
309
310 func TestResolveLogger(t *testing.T) {
311 cfg, err := LoadDefaultConfig(context.Background(), func(o *LoadOptions) error {
312 o.Logger = logging.Nop{}
313 return nil
314 })
315 if err != nil {
316 t.Fatalf("expect no error, got %v", err)
317 }
318
319 _, ok := cfg.Logger.(logging.Nop)
320 if !ok {
321 t.Error("unexpected logger type")
322 }
323 }
324
325 func TestResolveDefaultsMode(t *testing.T) {
326 cases := []struct {
327 Mode aws.DefaultsMode
328 ExpectedDefaultsMode aws.DefaultsMode
329 ExpectedRuntimeEnvironment aws.RuntimeEnvironment
330 WithIMDS func() *httptest.Server
331 Env map[string]string
332 }{
333 {
334 ExpectedDefaultsMode: aws.DefaultsModeLegacy,
335 },
336 {
337 Mode: aws.DefaultsModeStandard,
338 ExpectedDefaultsMode: aws.DefaultsModeStandard,
339 },
340 {
341 Mode: aws.DefaultsModeInRegion,
342 ExpectedDefaultsMode: aws.DefaultsModeInRegion,
343 },
344 {
345 Mode: aws.DefaultsModeCrossRegion,
346 ExpectedDefaultsMode: aws.DefaultsModeCrossRegion,
347 },
348 {
349 Mode: aws.DefaultsModeMobile,
350 ExpectedDefaultsMode: aws.DefaultsModeMobile,
351 },
352 {
353 Mode: aws.DefaultsModeAuto,
354 Env: map[string]string{
355 "AWS_EXECUTION_ENV": "envName",
356 "AWS_REGION": "us-west-2",
357 },
358 WithIMDS: func() *httptest.Server {
359 return httptest.NewServer(http.HandlerFunc(
360 func(w http.ResponseWriter, r *http.Request) {
361 if r.URL.Path == "/latest/dynamic/instance-identity/document" {
362 out, _ := json.Marshal(&imds.InstanceIdentityDocument{
363 Region: "us-west-2",
364 })
365 w.Write(out)
366 } else if r.URL.Path == "/latest/api/token" {
367 header := w.Header()
368
369 const ttlHeader = "X-Aws-Ec2-Metadata-Token-Ttl-Seconds"
370 header.Set(ttlHeader, r.Header.Get(ttlHeader))
371 w.Write([]byte("validToken"))
372 } else {
373 w.Write([]byte(""))
374 }
375 }))
376 },
377 ExpectedDefaultsMode: aws.DefaultsModeAuto,
378 ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
379 EnvironmentIdentifier: "envName",
380 Region: "us-west-2",
381 EC2InstanceMetadataRegion: "us-west-2",
382 },
383 },
384 {
385 Mode: aws.DefaultsModeAuto,
386 Env: map[string]string{
387 "AWS_EXECUTION_ENV": "envName",
388 "AWS_REGION": "us-west-2",
389 },
390 WithIMDS: func() *httptest.Server {
391 return httptest.NewServer(http.HandlerFunc(
392 func(w http.ResponseWriter, r *http.Request) {
393 w.WriteHeader(500)
394 }))
395 },
396 ExpectedDefaultsMode: aws.DefaultsModeAuto,
397 ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
398 EnvironmentIdentifier: "envName",
399 Region: "us-west-2",
400 EC2InstanceMetadataRegion: "",
401 },
402 },
403 {
404 Mode: aws.DefaultsModeAuto,
405 Env: map[string]string{
406 "AWS_EXECUTION_ENV": "envName",
407 "AWS_REGION": "us-west-2",
408 "AWS_EC2_METADATA_DISABLED": "true",
409 },
410 ExpectedDefaultsMode: aws.DefaultsModeAuto,
411 ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
412 EnvironmentIdentifier: "envName",
413 Region: "us-west-2",
414 EC2InstanceMetadataRegion: "",
415 },
416 },
417 {
418 Mode: aws.DefaultsModeAuto,
419 Env: map[string]string{
420 "AWS_REGION": "us-west-2",
421 "AWS_DEFAULT_REGION": "other",
422 "AWS_EC2_METADATA_DISABLED": "true",
423 },
424 ExpectedDefaultsMode: aws.DefaultsModeAuto,
425 ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
426 Region: "us-west-2",
427 },
428 },
429 {
430 Mode: aws.DefaultsModeAuto,
431 Env: map[string]string{
432 "AWS_DEFAULT_REGION": "us-west-2",
433 "AWS_EC2_METADATA_DISABLED": "true",
434 },
435 ExpectedDefaultsMode: aws.DefaultsModeAuto,
436 ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
437 Region: "us-west-2",
438 },
439 },
440 }
441
442 for i, tt := range cases {
443 t.Run(strconv.Itoa(i), func(t *testing.T) {
444 var server *httptest.Server
445 if tt.WithIMDS != nil {
446 server = tt.WithIMDS()
447 defer server.Close()
448 }
449 loadOptionsFunc := func(*LoadOptions) error {
450 return nil
451 }
452 if len(tt.Mode) != 0 {
453 loadOptionsFunc = WithDefaultsMode(tt.Mode, func(options *DefaultsModeOptions) {
454 if server != nil {
455 options.IMDSClient = imds.New(imds.Options{
456 Endpoint: server.URL,
457 })
458 }
459 })
460 }
461
462 if len(tt.Env) > 0 {
463 restoreEnv := awstesting.StashEnv()
464 defer awstesting.PopEnv(restoreEnv)
465
466 for key := range tt.Env {
467 _ = os.Setenv(key, tt.Env[key])
468 }
469 }
470
471 cfg, err := LoadDefaultConfig(context.Background(), loadOptionsFunc)
472 if err != nil {
473 t.Errorf("expect no error, got %v", err)
474 }
475
476 if diff := cmpDiff(tt.ExpectedDefaultsMode, cfg.DefaultsMode); len(diff) > 0 {
477 t.Errorf(diff)
478 }
479
480 if diff := cmpDiff(tt.ExpectedRuntimeEnvironment, cfg.RuntimeEnvironment); len(diff) > 0 {
481 t.Errorf(diff)
482 }
483 })
484 }
485 }
486
View as plain text