...

Source file src/go.mongodb.org/mongo-driver/internal/aws/credentials/chain_provider_test.go

Documentation: go.mongodb.org/mongo-driver/internal/aws/credentials

     1  // Copyright (C) MongoDB, Inc. 2023-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  //
     7  // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
     8  // - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider_test.go
     9  // See THIRD-PARTY-NOTICES for original license terms
    10  
    11  package credentials
    12  
    13  import (
    14  	"reflect"
    15  	"testing"
    16  
    17  	"go.mongodb.org/mongo-driver/internal/aws/awserr"
    18  )
    19  
    20  type secondStubProvider struct {
    21  	creds   Value
    22  	expired bool
    23  	err     error
    24  }
    25  
    26  func (s *secondStubProvider) Retrieve() (Value, error) {
    27  	s.expired = false
    28  	s.creds.ProviderName = "secondStubProvider"
    29  	return s.creds, s.err
    30  }
    31  func (s *secondStubProvider) IsExpired() bool {
    32  	return s.expired
    33  }
    34  
    35  func TestChainProviderWithNames(t *testing.T) {
    36  	p := &ChainProvider{
    37  		Providers: []Provider{
    38  			&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
    39  			&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
    40  			&secondStubProvider{
    41  				creds: Value{
    42  					AccessKeyID:     "AKIF",
    43  					SecretAccessKey: "NOSECRET",
    44  					SessionToken:    "",
    45  				},
    46  			},
    47  			&stubProvider{
    48  				creds: Value{
    49  					AccessKeyID:     "AKID",
    50  					SecretAccessKey: "SECRET",
    51  					SessionToken:    "",
    52  				},
    53  			},
    54  		},
    55  	}
    56  
    57  	creds, err := p.Retrieve()
    58  	if err != nil {
    59  		t.Errorf("Expect no error, got %v", err)
    60  	}
    61  	if e, a := "secondStubProvider", creds.ProviderName; e != a {
    62  		t.Errorf("Expect provider name to match, %v got, %v", e, a)
    63  	}
    64  
    65  	// Also check credentials
    66  	if e, a := "AKIF", creds.AccessKeyID; e != a {
    67  		t.Errorf("Expect access key ID to match, %v got %v", e, a)
    68  	}
    69  	if e, a := "NOSECRET", creds.SecretAccessKey; e != a {
    70  		t.Errorf("Expect secret access key to match, %v got %v", e, a)
    71  	}
    72  	if v := creds.SessionToken; len(v) != 0 {
    73  		t.Errorf("Expect session token to be empty, %v", v)
    74  	}
    75  
    76  }
    77  
    78  func TestChainProviderGet(t *testing.T) {
    79  	p := &ChainProvider{
    80  		Providers: []Provider{
    81  			&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
    82  			&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
    83  			&stubProvider{
    84  				creds: Value{
    85  					AccessKeyID:     "AKID",
    86  					SecretAccessKey: "SECRET",
    87  					SessionToken:    "",
    88  				},
    89  			},
    90  		},
    91  	}
    92  
    93  	creds, err := p.Retrieve()
    94  	if err != nil {
    95  		t.Errorf("Expect no error, got %v", err)
    96  	}
    97  	if e, a := "AKID", creds.AccessKeyID; e != a {
    98  		t.Errorf("Expect access key ID to match, %v got %v", e, a)
    99  	}
   100  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
   101  		t.Errorf("Expect secret access key to match, %v got %v", e, a)
   102  	}
   103  	if v := creds.SessionToken; len(v) != 0 {
   104  		t.Errorf("Expect session token to be empty, %v", v)
   105  	}
   106  }
   107  
   108  func TestChainProviderIsExpired(t *testing.T) {
   109  	stubProvider := &stubProvider{expired: true}
   110  	p := &ChainProvider{
   111  		Providers: []Provider{
   112  			stubProvider,
   113  		},
   114  	}
   115  
   116  	if !p.IsExpired() {
   117  		t.Errorf("Expect expired to be true before any Retrieve")
   118  	}
   119  	_, err := p.Retrieve()
   120  	if err != nil {
   121  		t.Errorf("Expect no error, got %v", err)
   122  	}
   123  	if p.IsExpired() {
   124  		t.Errorf("Expect not expired after retrieve")
   125  	}
   126  
   127  	stubProvider.expired = true
   128  	if !p.IsExpired() {
   129  		t.Errorf("Expect return of expired provider")
   130  	}
   131  
   132  	_, err = p.Retrieve()
   133  	if err != nil {
   134  		t.Errorf("Expect no error, got %v", err)
   135  	}
   136  	if p.IsExpired() {
   137  		t.Errorf("Expect not expired after retrieve")
   138  	}
   139  }
   140  
   141  func TestChainProviderWithNoProvider(t *testing.T) {
   142  	p := &ChainProvider{
   143  		Providers: []Provider{},
   144  	}
   145  
   146  	if !p.IsExpired() {
   147  		t.Errorf("Expect expired with no providers")
   148  	}
   149  	_, err := p.Retrieve()
   150  	if err.Error() != "NoCredentialProviders: no valid providers in chain" {
   151  		t.Errorf("Expect no providers error returned, got %v", err)
   152  	}
   153  }
   154  
   155  func TestChainProviderWithNoValidProvider(t *testing.T) {
   156  	errs := []error{
   157  		awserr.New("FirstError", "first provider error", nil),
   158  		awserr.New("SecondError", "second provider error", nil),
   159  	}
   160  	p := &ChainProvider{
   161  		Providers: []Provider{
   162  			&stubProvider{err: errs[0]},
   163  			&stubProvider{err: errs[1]},
   164  		},
   165  	}
   166  
   167  	if !p.IsExpired() {
   168  		t.Errorf("Expect expired with no providers")
   169  	}
   170  	_, err := p.Retrieve()
   171  
   172  	expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
   173  	if e, a := expectErr, err; !reflect.DeepEqual(e, a) {
   174  		t.Errorf("Expect no providers error returned, %v, got %v", e, a)
   175  	}
   176  }
   177  

View as plain text