...

Source file src/go.mongodb.org/mongo-driver/internal/aws/credentials/credentials_test.go

Documentation: go.mongodb.org/mongo-driver/internal/aws/credentials

     1  // Copyright (C) MongoDB, Inc. 2023-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  //
     7  // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
     8  // - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials_test.go
     9  // See THIRD-PARTY-NOTICES for original license terms
    10  
    11  package credentials
    12  
    13  import (
    14  	"context"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"go.mongodb.org/mongo-driver/internal/aws/awserr"
    20  )
    21  
    22  func isExpired(c *Credentials) bool {
    23  	c.m.RLock()
    24  	defer c.m.RUnlock()
    25  
    26  	return c.isExpiredLocked(c.creds)
    27  }
    28  
    29  type stubProvider struct {
    30  	creds          Value
    31  	retrievedCount int
    32  	expired        bool
    33  	err            error
    34  }
    35  
    36  func (s *stubProvider) Retrieve() (Value, error) {
    37  	s.retrievedCount++
    38  	s.expired = false
    39  	s.creds.ProviderName = "stubProvider"
    40  	return s.creds, s.err
    41  }
    42  func (s *stubProvider) IsExpired() bool {
    43  	return s.expired
    44  }
    45  
    46  func TestCredentialsGet(t *testing.T) {
    47  	c := NewCredentials(&stubProvider{
    48  		creds: Value{
    49  			AccessKeyID:     "AKID",
    50  			SecretAccessKey: "SECRET",
    51  			SessionToken:    "",
    52  		},
    53  		expired: true,
    54  	})
    55  
    56  	creds, err := c.GetWithContext(context.Background())
    57  	if err != nil {
    58  		t.Errorf("Expected no error, got %v", err)
    59  	}
    60  	if e, a := "AKID", creds.AccessKeyID; e != a {
    61  		t.Errorf("Expect access key ID to match, %v got %v", e, a)
    62  	}
    63  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
    64  		t.Errorf("Expect secret access key to match, %v got %v", e, a)
    65  	}
    66  	if v := creds.SessionToken; len(v) != 0 {
    67  		t.Errorf("Expect session token to be empty, %v", v)
    68  	}
    69  }
    70  
    71  func TestCredentialsGetWithError(t *testing.T) {
    72  	c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
    73  
    74  	_, err := c.GetWithContext(context.Background())
    75  	if e, a := "provider error", err.(awserr.Error).Code(); e != a {
    76  		t.Errorf("Expected provider error, %v got %v", e, a)
    77  	}
    78  }
    79  
    80  func TestCredentialsExpire(t *testing.T) {
    81  	stub := &stubProvider{}
    82  	c := NewCredentials(stub)
    83  
    84  	stub.expired = false
    85  	if !isExpired(c) {
    86  		t.Errorf("Expected to start out expired")
    87  	}
    88  
    89  	_, err := c.GetWithContext(context.Background())
    90  	if err != nil {
    91  		t.Errorf("Expected no err, got %v", err)
    92  	}
    93  	if isExpired(c) {
    94  		t.Errorf("Expected not to be expired")
    95  	}
    96  
    97  	stub.expired = true
    98  	if !isExpired(c) {
    99  		t.Errorf("Expected to be expired")
   100  	}
   101  }
   102  
   103  func TestCredentialsGetWithProviderName(t *testing.T) {
   104  	stub := &stubProvider{}
   105  
   106  	c := NewCredentials(stub)
   107  
   108  	creds, err := c.GetWithContext(context.Background())
   109  	if err != nil {
   110  		t.Errorf("Expected no error, got %v", err)
   111  	}
   112  	if e, a := creds.ProviderName, "stubProvider"; e != a {
   113  		t.Errorf("Expected provider name to match, %v got %v", e, a)
   114  	}
   115  }
   116  
   117  type MockProvider struct {
   118  	// The date/time when to expire on
   119  	expiration time.Time
   120  
   121  	// If set will be used by IsExpired to determine the current time.
   122  	// Defaults to time.Now if CurrentTime is not set.  Available for testing
   123  	// to be able to mock out the current time.
   124  	CurrentTime func() time.Time
   125  }
   126  
   127  // IsExpired returns if the credentials are expired.
   128  func (e *MockProvider) IsExpired() bool {
   129  	curTime := e.CurrentTime
   130  	if curTime == nil {
   131  		curTime = time.Now
   132  	}
   133  	return e.expiration.Before(curTime())
   134  }
   135  
   136  func (*MockProvider) Retrieve() (Value, error) {
   137  	return Value{}, nil
   138  }
   139  
   140  func TestCredentialsIsExpired_Race(_ *testing.T) {
   141  	creds := NewChainCredentials([]Provider{&MockProvider{}})
   142  
   143  	starter := make(chan struct{})
   144  	var wg sync.WaitGroup
   145  	wg.Add(10)
   146  	for i := 0; i < 10; i++ {
   147  		go func() {
   148  			defer wg.Done()
   149  			<-starter
   150  			for i := 0; i < 100; i++ {
   151  				isExpired(creds)
   152  			}
   153  		}()
   154  	}
   155  	close(starter)
   156  
   157  	wg.Wait()
   158  }
   159  
   160  type stubProviderConcurrent struct {
   161  	stubProvider
   162  	done chan struct{}
   163  }
   164  
   165  func (s *stubProviderConcurrent) Retrieve() (Value, error) {
   166  	<-s.done
   167  	return s.stubProvider.Retrieve()
   168  }
   169  
   170  func TestCredentialsGetConcurrent(t *testing.T) {
   171  	stub := &stubProviderConcurrent{
   172  		done: make(chan struct{}),
   173  	}
   174  
   175  	c := NewCredentials(stub)
   176  	done := make(chan struct{})
   177  
   178  	for i := 0; i < 2; i++ {
   179  		go func() {
   180  			_, err := c.GetWithContext(context.Background())
   181  			if err != nil {
   182  				t.Errorf("Expected no err, got %v", err)
   183  			}
   184  			done <- struct{}{}
   185  		}()
   186  	}
   187  
   188  	// Validates that a single call to Retrieve is shared between two calls to Get
   189  	stub.done <- struct{}{}
   190  	<-done
   191  	<-done
   192  }
   193  

View as plain text