...

Source file src/github.com/aws/aws-sdk-go-v2/config/resolve_test.go

Documentation: github.com/aws/aws-sdk-go-v2/config

     1  package config
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"os"
    11  	"strconv"
    12  	"testing"
    13  
    14  	"github.com/aws/aws-sdk-go-v2/aws"
    15  	awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
    16  	"github.com/aws/aws-sdk-go-v2/credentials"
    17  	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
    18  	"github.com/aws/aws-sdk-go-v2/internal/awstesting"
    19  	"github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
    20  	"github.com/aws/smithy-go/logging"
    21  )
    22  
    23  func TestResolveCustomCABundle(t *testing.T) {
    24  	var options LoadOptions
    25  	var cfg aws.Config
    26  	cfg.HTTPClient = awshttp.NewBuildableClient()
    27  
    28  	WithCustomCABundle(bytes.NewReader(awstesting.TLSBundleCA))(&options)
    29  	configs := configs{options}
    30  
    31  	if err := resolveCustomCABundle(context.Background(), &cfg, configs); err != nil {
    32  		t.Fatalf("expect no error, got %v", err)
    33  	}
    34  
    35  	type transportGetter interface {
    36  		GetTransport() *http.Transport
    37  	}
    38  
    39  	trGetter := cfg.HTTPClient.(transportGetter)
    40  	tr := trGetter.GetTransport()
    41  	if tr.TLSClientConfig.RootCAs == nil {
    42  		t.Errorf("expect root CAs set")
    43  	}
    44  }
    45  
    46  func TestResolveCustomCABundle_ValidCA(t *testing.T) {
    47  	certFile, keyFile, caFile, err := awstesting.CreateTLSBundleFiles()
    48  	if err != nil {
    49  		t.Fatalf("failed to create cert temp files, %v", err)
    50  	}
    51  	defer func() {
    52  		awstesting.CleanupTLSBundleFiles(certFile, keyFile, caFile)
    53  	}()
    54  
    55  	serverAddr, err := awstesting.CreateTLSServer(certFile, keyFile, nil)
    56  	if err != nil {
    57  		t.Fatalf("failed to start TLS server, %v", err)
    58  	}
    59  
    60  	caPEM, err := ioutil.ReadFile(caFile)
    61  	if err != nil {
    62  		t.Fatalf("failed to read CA file, %v", err)
    63  	}
    64  
    65  	var options LoadOptions
    66  	var cfg aws.Config
    67  	cfg.HTTPClient = awshttp.NewBuildableClient()
    68  
    69  	WithCustomCABundle(bytes.NewReader(caPEM))(&options)
    70  	configs := configs{options}
    71  
    72  	if err := resolveCustomCABundle(context.Background(), &cfg, configs); err != nil {
    73  		t.Fatalf("expect no error, got %v", err)
    74  	}
    75  
    76  	req, _ := http.NewRequest("GET", serverAddr, nil)
    77  	resp, err := cfg.HTTPClient.Do(req)
    78  	if err != nil {
    79  		t.Fatalf("failed to make request to TLS server, %v", err)
    80  	}
    81  	resp.Body.Close()
    82  
    83  	if e, a := http.StatusOK, resp.StatusCode; e != a {
    84  		t.Errorf("expect %v status, got %v", e, a)
    85  	}
    86  }
    87  
    88  func TestResolveCustomCABundle_ErrorCustomClient(t *testing.T) {
    89  	var options LoadOptions
    90  	var cfg aws.Config
    91  
    92  	cfg.HTTPClient = &http.Client{}
    93  
    94  	WithCustomCABundle(bytes.NewReader(awstesting.TLSBundleCA))(&options)
    95  	configs := configs{options}
    96  
    97  	if err := resolveCustomCABundle(context.Background(), &cfg, configs); err == nil {
    98  		t.Fatalf("expect error, got none")
    99  	}
   100  }
   101  
   102  func TestResolveRegion(t *testing.T) {
   103  	var options LoadOptions
   104  	optFns := []func(options *LoadOptions) error{
   105  		WithRegion("ignored-region"),
   106  
   107  		WithRegion("mock-region"),
   108  	}
   109  
   110  	for _, optFn := range optFns {
   111  		optFn(&options)
   112  	}
   113  
   114  	configs := configs{options}
   115  
   116  	var cfg aws.Config
   117  
   118  	if err := resolveRegion(context.Background(), &cfg, configs); err != nil {
   119  		t.Fatalf("expect no error, got %v", err)
   120  	}
   121  
   122  	if e, a := "mock-region", cfg.Region; e != a {
   123  		t.Errorf("expect %v region, got %v", e, a)
   124  	}
   125  }
   126  
   127  func TestResolveAppID(t *testing.T) {
   128  	var options LoadOptions
   129  	optFns := []func(options *LoadOptions) error{
   130  		WithAppID("1234"),
   131  
   132  		WithAppID("5678"),
   133  	}
   134  
   135  	for _, optFn := range optFns {
   136  		optFn(&options)
   137  	}
   138  
   139  	configs := configs{options}
   140  
   141  	var cfg aws.Config
   142  
   143  	if err := resolveAppID(context.Background(), &cfg, configs); err != nil {
   144  		t.Fatalf("expect no error, got %v", err)
   145  	}
   146  
   147  	if e, a := "5678", cfg.AppID; e != a {
   148  		t.Errorf("expect %v app ID, got %v", e, a)
   149  	}
   150  }
   151  
   152  func TestResolveRequestMinCompressSizeBytes(t *testing.T) {
   153  	cases := map[string]struct {
   154  		RequestMinCompressSizeBytes *int64
   155  		ExpectMinBytes              int64
   156  	}{
   157  		"min requet size of 100 bytes": {
   158  			RequestMinCompressSizeBytes: aws.Int64(100),
   159  			ExpectMinBytes:              100,
   160  		},
   161  		"min request size unset": {
   162  			ExpectMinBytes: 10240,
   163  		},
   164  	}
   165  
   166  	for name, c := range cases {
   167  		t.Run(name, func(t *testing.T) {
   168  			var options LoadOptions
   169  			optFns := []func(options *LoadOptions) error{
   170  				WithRequestMinCompressSizeBytes(c.RequestMinCompressSizeBytes),
   171  			}
   172  
   173  			for _, optFn := range optFns {
   174  				optFn(&options)
   175  			}
   176  
   177  			configs := configs{options}
   178  
   179  			var cfg aws.Config
   180  
   181  			if err := resolveRequestMinCompressSizeBytes(context.Background(), &cfg, configs); err != nil {
   182  				t.Fatalf("expect no error, got %v", err)
   183  			}
   184  
   185  			if e, a := c.ExpectMinBytes, cfg.RequestMinCompressSizeBytes; e != a {
   186  				t.Errorf("expect RequestMinCompressSizeBytes to be %v , got %v", e, a)
   187  			}
   188  		})
   189  	}
   190  }
   191  
   192  func TestResolveDisableRequestCompression(t *testing.T) {
   193  	cases := map[string]struct {
   194  		DisableRequestCompression *bool
   195  		ExpectDisable             bool
   196  	}{
   197  		"disable request compression": {
   198  			DisableRequestCompression: aws.Bool(true),
   199  			ExpectDisable:             true,
   200  		},
   201  		"disable request compression unset": {
   202  			ExpectDisable: false,
   203  		},
   204  	}
   205  
   206  	for name, c := range cases {
   207  		t.Run(name, func(t *testing.T) {
   208  			var options LoadOptions
   209  			optFns := []func(options *LoadOptions) error{
   210  				WithDisableRequestCompression(c.DisableRequestCompression),
   211  			}
   212  
   213  			for _, optFn := range optFns {
   214  				optFn(&options)
   215  			}
   216  
   217  			configs := configs{options}
   218  
   219  			var cfg aws.Config
   220  
   221  			if err := resolveDisableRequestCompression(context.Background(), &cfg, configs); err != nil {
   222  				t.Fatalf("expect no error, got %v", err)
   223  			}
   224  
   225  			if e, a := c.ExpectDisable, cfg.DisableRequestCompression; e != a {
   226  				t.Errorf("expect DisableRequestCompression to be %v , got %v", e, a)
   227  			}
   228  		})
   229  	}
   230  }
   231  
   232  func TestResolveCredentialsProvider(t *testing.T) {
   233  	var options LoadOptions
   234  	optFns := []func(options *LoadOptions) error{
   235  		WithCredentialsProvider(credentials.StaticCredentialsProvider{
   236  			Value: aws.Credentials{
   237  				AccessKeyID:     "AKID",
   238  				SecretAccessKey: "SECRET",
   239  				Source:          "valid",
   240  			}},
   241  		),
   242  	}
   243  
   244  	for _, optFn := range optFns {
   245  		optFn(&options)
   246  	}
   247  
   248  	configs := configs{options}
   249  
   250  	var cfg aws.Config
   251  	cfg.Credentials = nil
   252  
   253  	if found, err := resolveCredentialProvider(context.Background(), &cfg, configs); err != nil {
   254  		t.Fatalf("expect no error, got %v", err)
   255  	} else if e, a := true, found; e != a {
   256  		t.Fatalf("expected %v, got %v", e, a)
   257  	}
   258  
   259  	_, ok := cfg.Credentials.(*aws.CredentialsCache)
   260  	if !ok {
   261  		t.Fatalf("expect resolved credentials to be wrapped in cache, was not, %T", cfg.Credentials)
   262  	}
   263  
   264  	creds, err := cfg.Credentials.Retrieve(context.Background())
   265  	if err != nil {
   266  		t.Fatalf("expect no error, got %v", err)
   267  	}
   268  
   269  	if e, a := "AKID", creds.AccessKeyID; e != a {
   270  		t.Errorf("expect %v key, got %v", e, a)
   271  	}
   272  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
   273  		t.Errorf("expect %v secret, got %v", e, a)
   274  	}
   275  	if e, a := "valid", creds.Source; e != a {
   276  		t.Errorf("expect %v provider name, got %v", e, a)
   277  	}
   278  }
   279  
   280  func TestDefaultRegion(t *testing.T) {
   281  	ctx := context.Background()
   282  
   283  	var options LoadOptions
   284  	WithDefaultRegion("foo-region")(&options)
   285  
   286  	configs := configs{options}
   287  	cfg := unit.Config()
   288  
   289  	err := resolveDefaultRegion(ctx, &cfg, configs)
   290  	if err != nil {
   291  		t.Fatalf("expected no error, got %v", err)
   292  	}
   293  
   294  	if e, a := "mock-region", cfg.Region; e != a {
   295  		t.Errorf("expected %v, got %v", e, a)
   296  	}
   297  
   298  	cfg.Region = ""
   299  
   300  	err = resolveDefaultRegion(ctx, &cfg, configs)
   301  	if err != nil {
   302  		t.Fatalf("expected no error, got %v", err)
   303  	}
   304  
   305  	if e, a := "foo-region", cfg.Region; e != a {
   306  		t.Errorf("expected %v, got %v", e, a)
   307  	}
   308  }
   309  
   310  func TestResolveLogger(t *testing.T) {
   311  	cfg, err := LoadDefaultConfig(context.Background(), func(o *LoadOptions) error {
   312  		o.Logger = logging.Nop{}
   313  		return nil
   314  	})
   315  	if err != nil {
   316  		t.Fatalf("expect no error, got %v", err)
   317  	}
   318  
   319  	_, ok := cfg.Logger.(logging.Nop)
   320  	if !ok {
   321  		t.Error("unexpected logger type")
   322  	}
   323  }
   324  
   325  func TestResolveDefaultsMode(t *testing.T) {
   326  	cases := []struct {
   327  		Mode                       aws.DefaultsMode
   328  		ExpectedDefaultsMode       aws.DefaultsMode
   329  		ExpectedRuntimeEnvironment aws.RuntimeEnvironment
   330  		WithIMDS                   func() *httptest.Server
   331  		Env                        map[string]string
   332  	}{
   333  		{
   334  			ExpectedDefaultsMode: aws.DefaultsModeLegacy,
   335  		},
   336  		{
   337  			Mode:                 aws.DefaultsModeStandard,
   338  			ExpectedDefaultsMode: aws.DefaultsModeStandard,
   339  		},
   340  		{
   341  			Mode:                 aws.DefaultsModeInRegion,
   342  			ExpectedDefaultsMode: aws.DefaultsModeInRegion,
   343  		},
   344  		{
   345  			Mode:                 aws.DefaultsModeCrossRegion,
   346  			ExpectedDefaultsMode: aws.DefaultsModeCrossRegion,
   347  		},
   348  		{
   349  			Mode:                 aws.DefaultsModeMobile,
   350  			ExpectedDefaultsMode: aws.DefaultsModeMobile,
   351  		},
   352  		{
   353  			Mode: aws.DefaultsModeAuto,
   354  			Env: map[string]string{
   355  				"AWS_EXECUTION_ENV": "envName",
   356  				"AWS_REGION":        "us-west-2",
   357  			},
   358  			WithIMDS: func() *httptest.Server {
   359  				return httptest.NewServer(http.HandlerFunc(
   360  					func(w http.ResponseWriter, r *http.Request) {
   361  						if r.URL.Path == "/latest/dynamic/instance-identity/document" {
   362  							out, _ := json.Marshal(&imds.InstanceIdentityDocument{
   363  								Region: "us-west-2",
   364  							})
   365  							w.Write(out)
   366  						} else if r.URL.Path == "/latest/api/token" {
   367  							header := w.Header()
   368  							// bounce the TTL header
   369  							const ttlHeader = "X-Aws-Ec2-Metadata-Token-Ttl-Seconds"
   370  							header.Set(ttlHeader, r.Header.Get(ttlHeader))
   371  							w.Write([]byte("validToken"))
   372  						} else {
   373  							w.Write([]byte(""))
   374  						}
   375  					}))
   376  			},
   377  			ExpectedDefaultsMode: aws.DefaultsModeAuto,
   378  			ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
   379  				EnvironmentIdentifier:     "envName",
   380  				Region:                    "us-west-2",
   381  				EC2InstanceMetadataRegion: "us-west-2",
   382  			},
   383  		},
   384  		{
   385  			Mode: aws.DefaultsModeAuto,
   386  			Env: map[string]string{
   387  				"AWS_EXECUTION_ENV": "envName",
   388  				"AWS_REGION":        "us-west-2",
   389  			},
   390  			WithIMDS: func() *httptest.Server {
   391  				return httptest.NewServer(http.HandlerFunc(
   392  					func(w http.ResponseWriter, r *http.Request) {
   393  						w.WriteHeader(500)
   394  					}))
   395  			},
   396  			ExpectedDefaultsMode: aws.DefaultsModeAuto,
   397  			ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
   398  				EnvironmentIdentifier:     "envName",
   399  				Region:                    "us-west-2",
   400  				EC2InstanceMetadataRegion: "",
   401  			},
   402  		},
   403  		{
   404  			Mode: aws.DefaultsModeAuto,
   405  			Env: map[string]string{
   406  				"AWS_EXECUTION_ENV":         "envName",
   407  				"AWS_REGION":                "us-west-2",
   408  				"AWS_EC2_METADATA_DISABLED": "true",
   409  			},
   410  			ExpectedDefaultsMode: aws.DefaultsModeAuto,
   411  			ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
   412  				EnvironmentIdentifier:     "envName",
   413  				Region:                    "us-west-2",
   414  				EC2InstanceMetadataRegion: "",
   415  			},
   416  		},
   417  		{
   418  			Mode: aws.DefaultsModeAuto,
   419  			Env: map[string]string{
   420  				"AWS_REGION":                "us-west-2",
   421  				"AWS_DEFAULT_REGION":        "other",
   422  				"AWS_EC2_METADATA_DISABLED": "true",
   423  			},
   424  			ExpectedDefaultsMode: aws.DefaultsModeAuto,
   425  			ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
   426  				Region: "us-west-2",
   427  			},
   428  		},
   429  		{
   430  			Mode: aws.DefaultsModeAuto,
   431  			Env: map[string]string{
   432  				"AWS_DEFAULT_REGION":        "us-west-2",
   433  				"AWS_EC2_METADATA_DISABLED": "true",
   434  			},
   435  			ExpectedDefaultsMode: aws.DefaultsModeAuto,
   436  			ExpectedRuntimeEnvironment: aws.RuntimeEnvironment{
   437  				Region: "us-west-2",
   438  			},
   439  		},
   440  	}
   441  
   442  	for i, tt := range cases {
   443  		t.Run(strconv.Itoa(i), func(t *testing.T) {
   444  			var server *httptest.Server
   445  			if tt.WithIMDS != nil {
   446  				server = tt.WithIMDS()
   447  				defer server.Close()
   448  			}
   449  			loadOptionsFunc := func(*LoadOptions) error {
   450  				return nil
   451  			}
   452  			if len(tt.Mode) != 0 {
   453  				loadOptionsFunc = WithDefaultsMode(tt.Mode, func(options *DefaultsModeOptions) {
   454  					if server != nil {
   455  						options.IMDSClient = imds.New(imds.Options{
   456  							Endpoint: server.URL,
   457  						})
   458  					}
   459  				})
   460  			}
   461  
   462  			if len(tt.Env) > 0 {
   463  				restoreEnv := awstesting.StashEnv()
   464  				defer awstesting.PopEnv(restoreEnv)
   465  
   466  				for key := range tt.Env {
   467  					_ = os.Setenv(key, tt.Env[key])
   468  				}
   469  			}
   470  
   471  			cfg, err := LoadDefaultConfig(context.Background(), loadOptionsFunc)
   472  			if err != nil {
   473  				t.Errorf("expect no error, got %v", err)
   474  			}
   475  
   476  			if diff := cmpDiff(tt.ExpectedDefaultsMode, cfg.DefaultsMode); len(diff) > 0 {
   477  				t.Errorf(diff)
   478  			}
   479  
   480  			if diff := cmpDiff(tt.ExpectedRuntimeEnvironment, cfg.RuntimeEnvironment); len(diff) > 0 {
   481  				t.Errorf(diff)
   482  			}
   483  		})
   484  	}
   485  }
   486  

View as plain text