...

Source file src/go.mongodb.org/mongo-driver/internal/aws/credentials/credentials.go

Documentation: go.mongodb.org/mongo-driver/internal/aws/credentials

     1  // Copyright (C) MongoDB, Inc. 2023-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  //
     7  // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
     8  // - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials.go
     9  // See THIRD-PARTY-NOTICES for original license terms
    10  
    11  package credentials
    12  
    13  import (
    14  	"context"
    15  	"sync"
    16  	"time"
    17  
    18  	"go.mongodb.org/mongo-driver/internal/aws/awserr"
    19  	"golang.org/x/sync/singleflight"
    20  )
    21  
    22  // A Value is the AWS credentials value for individual credential fields.
    23  //
    24  // A Value is also used to represent Azure credentials.
    25  // Azure credentials only consist of an access token, which is stored in the `SessionToken` field.
    26  type Value struct {
    27  	// AWS Access key ID
    28  	AccessKeyID string
    29  
    30  	// AWS Secret Access Key
    31  	SecretAccessKey string
    32  
    33  	// AWS Session Token
    34  	SessionToken string
    35  
    36  	// Provider used to get credentials
    37  	ProviderName string
    38  }
    39  
    40  // HasKeys returns if the credentials Value has both AccessKeyID and
    41  // SecretAccessKey value set.
    42  func (v Value) HasKeys() bool {
    43  	return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
    44  }
    45  
    46  // A Provider is the interface for any component which will provide credentials
    47  // Value. A provider is required to manage its own Expired state, and what to
    48  // be expired means.
    49  //
    50  // The Provider should not need to implement its own mutexes, because
    51  // that will be managed by Credentials.
    52  type Provider interface {
    53  	// Retrieve returns nil if it successfully retrieved the value.
    54  	// Error is returned if the value were not obtainable, or empty.
    55  	Retrieve() (Value, error)
    56  
    57  	// IsExpired returns if the credentials are no longer valid, and need
    58  	// to be retrieved.
    59  	IsExpired() bool
    60  }
    61  
    62  // ProviderWithContext is a Provider that can retrieve credentials with a Context
    63  type ProviderWithContext interface {
    64  	Provider
    65  
    66  	RetrieveWithContext(context.Context) (Value, error)
    67  }
    68  
    69  // A Credentials provides concurrency safe retrieval of AWS credentials Value.
    70  //
    71  // A Credentials is also used to fetch Azure credentials Value.
    72  //
    73  // Credentials will cache the credentials value until they expire. Once the value
    74  // expires the next Get will attempt to retrieve valid credentials.
    75  //
    76  // Credentials is safe to use across multiple goroutines and will manage the
    77  // synchronous state so the Providers do not need to implement their own
    78  // synchronization.
    79  //
    80  // The first Credentials.Get() will always call Provider.Retrieve() to get the
    81  // first instance of the credentials Value. All calls to Get() after that
    82  // will return the cached credentials Value until IsExpired() returns true.
    83  type Credentials struct {
    84  	sf singleflight.Group
    85  
    86  	m        sync.RWMutex
    87  	creds    Value
    88  	provider Provider
    89  }
    90  
    91  // NewCredentials returns a pointer to a new Credentials with the provider set.
    92  func NewCredentials(provider Provider) *Credentials {
    93  	c := &Credentials{
    94  		provider: provider,
    95  	}
    96  	return c
    97  }
    98  
    99  // GetWithContext returns the credentials value, or error if the credentials
   100  // Value failed to be retrieved. Will return early if the passed in context is
   101  // canceled.
   102  //
   103  // Will return the cached credentials Value if it has not expired. If the
   104  // credentials Value has expired the Provider's Retrieve() will be called
   105  // to refresh the credentials.
   106  //
   107  // If Credentials.Expire() was called the credentials Value will be force
   108  // expired, and the next call to Get() will cause them to be refreshed.
   109  func (c *Credentials) GetWithContext(ctx context.Context) (Value, error) {
   110  	// Check if credentials are cached, and not expired.
   111  	select {
   112  	case curCreds, ok := <-c.asyncIsExpired():
   113  		// ok will only be true, of the credentials were not expired. ok will
   114  		// be false and have no value if the credentials are expired.
   115  		if ok {
   116  			return curCreds, nil
   117  		}
   118  	case <-ctx.Done():
   119  		return Value{}, awserr.New("RequestCanceled",
   120  			"request context canceled", ctx.Err())
   121  	}
   122  
   123  	// Cannot pass context down to the actual retrieve, because the first
   124  	// context would cancel the whole group when there is not direct
   125  	// association of items in the group.
   126  	resCh := c.sf.DoChan("", func() (interface{}, error) {
   127  		return c.singleRetrieve(&suppressedContext{ctx})
   128  	})
   129  	select {
   130  	case res := <-resCh:
   131  		return res.Val.(Value), res.Err
   132  	case <-ctx.Done():
   133  		return Value{}, awserr.New("RequestCanceled",
   134  			"request context canceled", ctx.Err())
   135  	}
   136  }
   137  
   138  func (c *Credentials) singleRetrieve(ctx context.Context) (interface{}, error) {
   139  	c.m.Lock()
   140  	defer c.m.Unlock()
   141  
   142  	if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
   143  		return curCreds, nil
   144  	}
   145  
   146  	var creds Value
   147  	var err error
   148  	if p, ok := c.provider.(ProviderWithContext); ok {
   149  		creds, err = p.RetrieveWithContext(ctx)
   150  	} else {
   151  		creds, err = c.provider.Retrieve()
   152  	}
   153  	if err == nil {
   154  		c.creds = creds
   155  	}
   156  
   157  	return creds, err
   158  }
   159  
   160  // asyncIsExpired returns a channel of credentials Value. If the channel is
   161  // closed the credentials are expired and credentials value are not empty.
   162  func (c *Credentials) asyncIsExpired() <-chan Value {
   163  	ch := make(chan Value, 1)
   164  	go func() {
   165  		c.m.RLock()
   166  		defer c.m.RUnlock()
   167  
   168  		if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
   169  			ch <- curCreds
   170  		}
   171  
   172  		close(ch)
   173  	}()
   174  
   175  	return ch
   176  }
   177  
   178  // isExpiredLocked helper method wrapping the definition of expired credentials.
   179  func (c *Credentials) isExpiredLocked(creds interface{}) bool {
   180  	return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired()
   181  }
   182  
   183  type suppressedContext struct {
   184  	context.Context
   185  }
   186  
   187  func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
   188  	return time.Time{}, false
   189  }
   190  
   191  func (s *suppressedContext) Done() <-chan struct{} {
   192  	return nil
   193  }
   194  
   195  func (s *suppressedContext) Err() error {
   196  	return nil
   197  }
   198  

View as plain text