...
1
2
3
4
5
6
7 package creds
8
9 import (
10 "context"
11 "fmt"
12 "io"
13 "net/http"
14 "net/http/httptest"
15 "net/url"
16 "sync/atomic"
17 "testing"
18 "time"
19
20 "go.mongodb.org/mongo-driver/internal/assert"
21 "go.mongodb.org/mongo-driver/internal/aws/credentials"
22 "go.mongodb.org/mongo-driver/internal/credproviders"
23 )
24
// pipeTransport is an http.RoundTripper used by the tests in this file. Its
// RoundTrip method sends every request to url instead of the requested
// address, and records the originally requested URL in the query parameter
// named by param.
type pipeTransport struct {
	// url is the address every request is sent to; param is the name of the
	// query parameter that carries the originally requested URL.
	url, param string
	// client performs the rewritten request.
	client *http.Client
}
30
31
32 func (t pipeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
33 uri, err := url.Parse(t.url)
34 if err != nil {
35 return nil, err
36 }
37 values := uri.Query()
38 values.Add(t.param, req.URL.String())
39 uri.RawQuery = values.Encode()
40 req.URL = uri
41 return t.client.Do(req)
42 }
43
44 func TestAWSCredentialProviderCaching(t *testing.T) {
45 const (
46 urienv = "TEST_CONTAINER_CREDENTIALS_RELATIVE_URI"
47 keyenv = "TEST_ACCESS_KEY"
48 awsRelativeURI = "http://169.254.170.2/"
49 testEndpoint = "foo"
50 param = "source"
51 )
52
53 t.Setenv(urienv, testEndpoint)
54
55 testCases := []struct {
56 expiration time.Duration
57 reqCount uint32
58 }{
59 {
60 expiration: 20 * time.Minute,
61 reqCount: 1,
62 },
63 {
64 expiration: 5 * time.Minute,
65 reqCount: 2,
66 },
67 {
68 expiration: -1 * time.Minute,
69 reqCount: 2,
70 },
71 }
72 for _, tc := range testCases {
73 t.Run(fmt.Sprintf("expires in %s", tc.expiration.String()), func(t *testing.T) {
74 var cnt uint32
75
76 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
77 if r.URL.Query().Get(param) != awsRelativeURI+testEndpoint {
78 w.WriteHeader(http.StatusNotFound)
79 return
80 }
81 atomic.AddUint32(&cnt, 1)
82 t := time.Now().Add(tc.expiration).Format(time.RFC3339)
83 _, err := io.WriteString(w, fmt.Sprintf(`{
84 "AccessKeyId": "id",
85 "SecretAccessKey": "key",
86 "Token": "token",
87 "Expiration": "%s"
88 }`, t))
89 if err != nil {
90 w.WriteHeader(http.StatusInternalServerError)
91 }
92 }))
93 defer ts.Close()
94
95 client := &http.Client{
96 Transport: pipeTransport{
97 url: ts.URL,
98 param: param,
99 client: ts.Client(),
100 },
101 }
102
103 env := credproviders.NewEnvProvider()
104 env.AwsAccessKeyIDEnv = credproviders.EnvVar(keyenv)
105 ecs := credproviders.NewECSProvider(client, expiryWindow)
106 ecs.AwsContainerCredentialsRelativeURIEnv = credproviders.EnvVar(urienv)
107
108 p := AWSCredentialProvider{credentials.NewChainCredentials([]credentials.Provider{env, ecs})}
109 var err error
110 _, err = p.GetCredentialsDoc(context.Background())
111 assert.NoError(t, err, "error in GetCredentialsDoc")
112 _, err = p.GetCredentialsDoc(context.Background())
113 assert.NoError(t, err, "error in GetCredentialsDoc")
114 assert.Equal(t, tc.reqCount, atomic.LoadUint32(&cnt), "expected and actual credentials retrieval count don't match")
115 })
116 }
117 }
118
View as plain text