// Copyright (C) MongoDB, Inc. 2023-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 // // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from: // - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider_test.go // See THIRD-PARTY-NOTICES for original license terms package credentials import ( "reflect" "testing" "go.mongodb.org/mongo-driver/internal/aws/awserr" ) type secondStubProvider struct { creds Value expired bool err error } func (s *secondStubProvider) Retrieve() (Value, error) { s.expired = false s.creds.ProviderName = "secondStubProvider" return s.creds, s.err } func (s *secondStubProvider) IsExpired() bool { return s.expired } func TestChainProviderWithNames(t *testing.T) { p := &ChainProvider{ Providers: []Provider{ &stubProvider{err: awserr.New("FirstError", "first provider error", nil)}, &stubProvider{err: awserr.New("SecondError", "second provider error", nil)}, &secondStubProvider{ creds: Value{ AccessKeyID: "AKIF", SecretAccessKey: "NOSECRET", SessionToken: "", }, }, &stubProvider{ creds: Value{ AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "", }, }, }, } creds, err := p.Retrieve() if err != nil { t.Errorf("Expect no error, got %v", err) } if e, a := "secondStubProvider", creds.ProviderName; e != a { t.Errorf("Expect provider name to match, %v got, %v", e, a) } // Also check credentials if e, a := "AKIF", creds.AccessKeyID; e != a { t.Errorf("Expect access key ID to match, %v got %v", e, a) } if e, a := "NOSECRET", creds.SecretAccessKey; e != a { t.Errorf("Expect secret access key to match, %v got %v", e, a) } if v := creds.SessionToken; len(v) != 0 { t.Errorf("Expect session token to be empty, %v", v) } } func TestChainProviderGet(t *testing.T) { p := &ChainProvider{ Providers: []Provider{ &stubProvider{err: awserr.New("FirstError", "first provider error", nil)}, &stubProvider{err: awserr.New("SecondError", "second provider error", nil)}, &stubProvider{ creds: Value{ AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "", }, }, }, } creds, err := p.Retrieve() if err != nil { t.Errorf("Expect no error, got %v", err) } if e, a := "AKID", creds.AccessKeyID; e != a { t.Errorf("Expect access key ID to match, %v got %v", e, a) } if e, a := "SECRET", creds.SecretAccessKey; e != a { t.Errorf("Expect secret access key to match, %v got %v", e, a) } if v := creds.SessionToken; len(v) != 0 { t.Errorf("Expect session token to be empty, %v", v) } } func TestChainProviderIsExpired(t *testing.T) { stubProvider := &stubProvider{expired: true} p := &ChainProvider{ Providers: []Provider{ stubProvider, }, } if !p.IsExpired() { t.Errorf("Expect expired to be true before any Retrieve") } _, err := p.Retrieve() if err != nil { t.Errorf("Expect no error, got %v", err) } if p.IsExpired() { t.Errorf("Expect not expired after retrieve") } stubProvider.expired = true if !p.IsExpired() { t.Errorf("Expect return of expired provider") } _, err = p.Retrieve() if err != nil { t.Errorf("Expect no error, got %v", err) } if p.IsExpired() { t.Errorf("Expect not expired after retrieve") } } func TestChainProviderWithNoProvider(t *testing.T) { p := &ChainProvider{ Providers: []Provider{}, } if !p.IsExpired() { t.Errorf("Expect expired with no providers") } _, err := p.Retrieve() if err.Error() != "NoCredentialProviders: no valid providers in chain" { t.Errorf("Expect no providers error returned, got %v", err) } } func TestChainProviderWithNoValidProvider(t *testing.T) { errs := []error{ awserr.New("FirstError", "first provider error", nil), awserr.New("SecondError", "second provider error", nil), } p := &ChainProvider{ Providers: []Provider{ &stubProvider{err: errs[0]}, &stubProvider{err: errs[1]}, }, } if !p.IsExpired() { t.Errorf("Expect expired with no providers") } _, err := p.Retrieve() expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs) if e, a := expectErr, err; !reflect.DeepEqual(e, a) { t.Errorf("Expect no providers error returned, %v, got %v", e, a) } }