1
2
3
4
5
6
7
8
9
10
11 package credentials
12
13 import (
14 "reflect"
15 "testing"
16
17 "go.mongodb.org/mongo-driver/internal/aws/awserr"
18 )
19
20 type secondStubProvider struct {
21 creds Value
22 expired bool
23 err error
24 }
25
26 func (s *secondStubProvider) Retrieve() (Value, error) {
27 s.expired = false
28 s.creds.ProviderName = "secondStubProvider"
29 return s.creds, s.err
30 }
31 func (s *secondStubProvider) IsExpired() bool {
32 return s.expired
33 }
34
35 func TestChainProviderWithNames(t *testing.T) {
36 p := &ChainProvider{
37 Providers: []Provider{
38 &stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
39 &stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
40 &secondStubProvider{
41 creds: Value{
42 AccessKeyID: "AKIF",
43 SecretAccessKey: "NOSECRET",
44 SessionToken: "",
45 },
46 },
47 &stubProvider{
48 creds: Value{
49 AccessKeyID: "AKID",
50 SecretAccessKey: "SECRET",
51 SessionToken: "",
52 },
53 },
54 },
55 }
56
57 creds, err := p.Retrieve()
58 if err != nil {
59 t.Errorf("Expect no error, got %v", err)
60 }
61 if e, a := "secondStubProvider", creds.ProviderName; e != a {
62 t.Errorf("Expect provider name to match, %v got, %v", e, a)
63 }
64
65
66 if e, a := "AKIF", creds.AccessKeyID; e != a {
67 t.Errorf("Expect access key ID to match, %v got %v", e, a)
68 }
69 if e, a := "NOSECRET", creds.SecretAccessKey; e != a {
70 t.Errorf("Expect secret access key to match, %v got %v", e, a)
71 }
72 if v := creds.SessionToken; len(v) != 0 {
73 t.Errorf("Expect session token to be empty, %v", v)
74 }
75
76 }
77
78 func TestChainProviderGet(t *testing.T) {
79 p := &ChainProvider{
80 Providers: []Provider{
81 &stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
82 &stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
83 &stubProvider{
84 creds: Value{
85 AccessKeyID: "AKID",
86 SecretAccessKey: "SECRET",
87 SessionToken: "",
88 },
89 },
90 },
91 }
92
93 creds, err := p.Retrieve()
94 if err != nil {
95 t.Errorf("Expect no error, got %v", err)
96 }
97 if e, a := "AKID", creds.AccessKeyID; e != a {
98 t.Errorf("Expect access key ID to match, %v got %v", e, a)
99 }
100 if e, a := "SECRET", creds.SecretAccessKey; e != a {
101 t.Errorf("Expect secret access key to match, %v got %v", e, a)
102 }
103 if v := creds.SessionToken; len(v) != 0 {
104 t.Errorf("Expect session token to be empty, %v", v)
105 }
106 }
107
108 func TestChainProviderIsExpired(t *testing.T) {
109 stubProvider := &stubProvider{expired: true}
110 p := &ChainProvider{
111 Providers: []Provider{
112 stubProvider,
113 },
114 }
115
116 if !p.IsExpired() {
117 t.Errorf("Expect expired to be true before any Retrieve")
118 }
119 _, err := p.Retrieve()
120 if err != nil {
121 t.Errorf("Expect no error, got %v", err)
122 }
123 if p.IsExpired() {
124 t.Errorf("Expect not expired after retrieve")
125 }
126
127 stubProvider.expired = true
128 if !p.IsExpired() {
129 t.Errorf("Expect return of expired provider")
130 }
131
132 _, err = p.Retrieve()
133 if err != nil {
134 t.Errorf("Expect no error, got %v", err)
135 }
136 if p.IsExpired() {
137 t.Errorf("Expect not expired after retrieve")
138 }
139 }
140
141 func TestChainProviderWithNoProvider(t *testing.T) {
142 p := &ChainProvider{
143 Providers: []Provider{},
144 }
145
146 if !p.IsExpired() {
147 t.Errorf("Expect expired with no providers")
148 }
149 _, err := p.Retrieve()
150 if err.Error() != "NoCredentialProviders: no valid providers in chain" {
151 t.Errorf("Expect no providers error returned, got %v", err)
152 }
153 }
154
155 func TestChainProviderWithNoValidProvider(t *testing.T) {
156 errs := []error{
157 awserr.New("FirstError", "first provider error", nil),
158 awserr.New("SecondError", "second provider error", nil),
159 }
160 p := &ChainProvider{
161 Providers: []Provider{
162 &stubProvider{err: errs[0]},
163 &stubProvider{err: errs[1]},
164 },
165 }
166
167 if !p.IsExpired() {
168 t.Errorf("Expect expired with no providers")
169 }
170 _, err := p.Retrieve()
171
172 expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
173 if e, a := expectErr, err; !reflect.DeepEqual(e, a) {
174 t.Errorf("Expect no providers error returned, %v, got %v", e, a)
175 }
176 }
177
View as plain text