1 // Copyright (C) MongoDB, Inc. 2023-present. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may 4 // not use this file except in compliance with the License. You may obtain 5 // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 // 7 // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: 8 // - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials.go 9 // See THIRD-PARTY-NOTICES for original license terms 10 11 package credentials 12 13 import ( 14 "context" 15 "sync" 16 "time" 17 18 "go.mongodb.org/mongo-driver/internal/aws/awserr" 19 "golang.org/x/sync/singleflight" 20 ) 21 22 // A Value is the AWS credentials value for individual credential fields. 23 // 24 // A Value is also used to represent Azure credentials. 25 // Azure credentials only consist of an access token, which is stored in the `SessionToken` field. 26 type Value struct { 27 // AWS Access key ID 28 AccessKeyID string 29 30 // AWS Secret Access Key 31 SecretAccessKey string 32 33 // AWS Session Token 34 SessionToken string 35 36 // Provider used to get credentials 37 ProviderName string 38 } 39 40 // HasKeys returns if the credentials Value has both AccessKeyID and 41 // SecretAccessKey value set. 42 func (v Value) HasKeys() bool { 43 return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0 44 } 45 46 // A Provider is the interface for any component which will provide credentials 47 // Value. A provider is required to manage its own Expired state, and what to 48 // be expired means. 49 // 50 // The Provider should not need to implement its own mutexes, because 51 // that will be managed by Credentials. 52 type Provider interface { 53 // Retrieve returns nil if it successfully retrieved the value. 54 // Error is returned if the value were not obtainable, or empty. 55 Retrieve() (Value, error) 56 57 // IsExpired returns if the credentials are no longer valid, and need 58 // to be retrieved. 59 IsExpired() bool 60 } 61 62 // ProviderWithContext is a Provider that can retrieve credentials with a Context 63 type ProviderWithContext interface { 64 Provider 65 66 RetrieveWithContext(context.Context) (Value, error) 67 } 68 69 // A Credentials provides concurrency safe retrieval of AWS credentials Value. 70 // 71 // A Credentials is also used to fetch Azure credentials Value. 72 // 73 // Credentials will cache the credentials value until they expire. Once the value 74 // expires the next Get will attempt to retrieve valid credentials. 75 // 76 // Credentials is safe to use across multiple goroutines and will manage the 77 // synchronous state so the Providers do not need to implement their own 78 // synchronization. 79 // 80 // The first Credentials.Get() will always call Provider.Retrieve() to get the 81 // first instance of the credentials Value. All calls to Get() after that 82 // will return the cached credentials Value until IsExpired() returns true. 83 type Credentials struct { 84 sf singleflight.Group 85 86 m sync.RWMutex 87 creds Value 88 provider Provider 89 } 90 91 // NewCredentials returns a pointer to a new Credentials with the provider set. 92 func NewCredentials(provider Provider) *Credentials { 93 c := &Credentials{ 94 provider: provider, 95 } 96 return c 97 } 98 99 // GetWithContext returns the credentials value, or error if the credentials 100 // Value failed to be retrieved. Will return early if the passed in context is 101 // canceled. 102 // 103 // Will return the cached credentials Value if it has not expired. If the 104 // credentials Value has expired the Provider's Retrieve() will be called 105 // to refresh the credentials. 106 // 107 // If Credentials.Expire() was called the credentials Value will be force 108 // expired, and the next call to Get() will cause them to be refreshed. 109 func (c *Credentials) GetWithContext(ctx context.Context) (Value, error) { 110 // Check if credentials are cached, and not expired. 111 select { 112 case curCreds, ok := <-c.asyncIsExpired(): 113 // ok will only be true, of the credentials were not expired. ok will 114 // be false and have no value if the credentials are expired. 115 if ok { 116 return curCreds, nil 117 } 118 case <-ctx.Done(): 119 return Value{}, awserr.New("RequestCanceled", 120 "request context canceled", ctx.Err()) 121 } 122 123 // Cannot pass context down to the actual retrieve, because the first 124 // context would cancel the whole group when there is not direct 125 // association of items in the group. 126 resCh := c.sf.DoChan("", func() (interface{}, error) { 127 return c.singleRetrieve(&suppressedContext{ctx}) 128 }) 129 select { 130 case res := <-resCh: 131 return res.Val.(Value), res.Err 132 case <-ctx.Done(): 133 return Value{}, awserr.New("RequestCanceled", 134 "request context canceled", ctx.Err()) 135 } 136 } 137 138 func (c *Credentials) singleRetrieve(ctx context.Context) (interface{}, error) { 139 c.m.Lock() 140 defer c.m.Unlock() 141 142 if curCreds := c.creds; !c.isExpiredLocked(curCreds) { 143 return curCreds, nil 144 } 145 146 var creds Value 147 var err error 148 if p, ok := c.provider.(ProviderWithContext); ok { 149 creds, err = p.RetrieveWithContext(ctx) 150 } else { 151 creds, err = c.provider.Retrieve() 152 } 153 if err == nil { 154 c.creds = creds 155 } 156 157 return creds, err 158 } 159 160 // asyncIsExpired returns a channel of credentials Value. If the channel is 161 // closed the credentials are expired and credentials value are not empty. 162 func (c *Credentials) asyncIsExpired() <-chan Value { 163 ch := make(chan Value, 1) 164 go func() { 165 c.m.RLock() 166 defer c.m.RUnlock() 167 168 if curCreds := c.creds; !c.isExpiredLocked(curCreds) { 169 ch <- curCreds 170 } 171 172 close(ch) 173 }() 174 175 return ch 176 } 177 178 // isExpiredLocked helper method wrapping the definition of expired credentials. 179 func (c *Credentials) isExpiredLocked(creds interface{}) bool { 180 return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired() 181 } 182 183 type suppressedContext struct { 184 context.Context 185 } 186 187 func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) { 188 return time.Time{}, false 189 } 190 191 func (s *suppressedContext) Done() <-chan struct{} { 192 return nil 193 } 194 195 func (s *suppressedContext) Err() error { 196 return nil 197 } 198