1
2
3
4
5
6
7
8
9
10
11 package credentials
12
13 import (
14 "context"
15 "sync"
16 "testing"
17 "time"
18
19 "go.mongodb.org/mongo-driver/internal/aws/awserr"
20 )
21
22 func isExpired(c *Credentials) bool {
23 c.m.RLock()
24 defer c.m.RUnlock()
25
26 return c.isExpiredLocked(c.creds)
27 }
28
29 type stubProvider struct {
30 creds Value
31 retrievedCount int
32 expired bool
33 err error
34 }
35
36 func (s *stubProvider) Retrieve() (Value, error) {
37 s.retrievedCount++
38 s.expired = false
39 s.creds.ProviderName = "stubProvider"
40 return s.creds, s.err
41 }
42 func (s *stubProvider) IsExpired() bool {
43 return s.expired
44 }
45
46 func TestCredentialsGet(t *testing.T) {
47 c := NewCredentials(&stubProvider{
48 creds: Value{
49 AccessKeyID: "AKID",
50 SecretAccessKey: "SECRET",
51 SessionToken: "",
52 },
53 expired: true,
54 })
55
56 creds, err := c.GetWithContext(context.Background())
57 if err != nil {
58 t.Errorf("Expected no error, got %v", err)
59 }
60 if e, a := "AKID", creds.AccessKeyID; e != a {
61 t.Errorf("Expect access key ID to match, %v got %v", e, a)
62 }
63 if e, a := "SECRET", creds.SecretAccessKey; e != a {
64 t.Errorf("Expect secret access key to match, %v got %v", e, a)
65 }
66 if v := creds.SessionToken; len(v) != 0 {
67 t.Errorf("Expect session token to be empty, %v", v)
68 }
69 }
70
71 func TestCredentialsGetWithError(t *testing.T) {
72 c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
73
74 _, err := c.GetWithContext(context.Background())
75 if e, a := "provider error", err.(awserr.Error).Code(); e != a {
76 t.Errorf("Expected provider error, %v got %v", e, a)
77 }
78 }
79
80 func TestCredentialsExpire(t *testing.T) {
81 stub := &stubProvider{}
82 c := NewCredentials(stub)
83
84 stub.expired = false
85 if !isExpired(c) {
86 t.Errorf("Expected to start out expired")
87 }
88
89 _, err := c.GetWithContext(context.Background())
90 if err != nil {
91 t.Errorf("Expected no err, got %v", err)
92 }
93 if isExpired(c) {
94 t.Errorf("Expected not to be expired")
95 }
96
97 stub.expired = true
98 if !isExpired(c) {
99 t.Errorf("Expected to be expired")
100 }
101 }
102
103 func TestCredentialsGetWithProviderName(t *testing.T) {
104 stub := &stubProvider{}
105
106 c := NewCredentials(stub)
107
108 creds, err := c.GetWithContext(context.Background())
109 if err != nil {
110 t.Errorf("Expected no error, got %v", err)
111 }
112 if e, a := creds.ProviderName, "stubProvider"; e != a {
113 t.Errorf("Expected provider name to match, %v got %v", e, a)
114 }
115 }
116
// MockProvider is a test provider whose expiry is decided by comparing a fixed
// expiration instant against a pluggable clock.
type MockProvider struct {
	// expiration is the instant after which the provider reports itself expired.
	expiration time.Time

	// CurrentTime, if non-nil, supplies "now" for IsExpired; when nil,
	// time.Now is used instead.
	CurrentTime func() time.Time
}

// IsExpired reports whether the current time (from CurrentTime, or time.Now
// when CurrentTime is nil) is strictly after the configured expiration.
func (p *MockProvider) IsExpired() bool {
	now := time.Now
	if p.CurrentTime != nil {
		now = p.CurrentTime
	}
	return now().After(p.expiration)
}
135
136 func (*MockProvider) Retrieve() (Value, error) {
137 return Value{}, nil
138 }
139
140 func TestCredentialsIsExpired_Race(_ *testing.T) {
141 creds := NewChainCredentials([]Provider{&MockProvider{}})
142
143 starter := make(chan struct{})
144 var wg sync.WaitGroup
145 wg.Add(10)
146 for i := 0; i < 10; i++ {
147 go func() {
148 defer wg.Done()
149 <-starter
150 for i := 0; i < 100; i++ {
151 isExpired(creds)
152 }
153 }()
154 }
155 close(starter)
156
157 wg.Wait()
158 }
159
160 type stubProviderConcurrent struct {
161 stubProvider
162 done chan struct{}
163 }
164
165 func (s *stubProviderConcurrent) Retrieve() (Value, error) {
166 <-s.done
167 return s.stubProvider.Retrieve()
168 }
169
170 func TestCredentialsGetConcurrent(t *testing.T) {
171 stub := &stubProviderConcurrent{
172 done: make(chan struct{}),
173 }
174
175 c := NewCredentials(stub)
176 done := make(chan struct{})
177
178 for i := 0; i < 2; i++ {
179 go func() {
180 _, err := c.GetWithContext(context.Background())
181 if err != nil {
182 t.Errorf("Expected no err, got %v", err)
183 }
184 done <- struct{}{}
185 }()
186 }
187
188
189 stub.done <- struct{}{}
190 <-done
191 <-done
192 }
193
View as plain text