1
2
3
4
5
6
7 package credproviders
8
9 import (
10 "context"
11 "encoding/json"
12 "errors"
13 "fmt"
14 "io/ioutil"
15 "net/http"
16 "time"
17
18 "go.mongodb.org/mongo-driver/internal/aws/credentials"
19 )
20
// ec2ProviderName is the ProviderName attached to every credentials.Value
// produced by EC2Provider.
const ec2ProviderName = "EC2Provider"

// Locations on the EC2 instance metadata endpoint queried by EC2Provider.
const (
	// awsEC2URI is the base address of the metadata endpoint (a link-local
	// address, so it is only reachable from the instance itself).
	awsEC2URI = "http://169.254.169.254/"
	// awsEC2RolePath lists the instance's IAM role; appending a role name to it
	// addresses that role's credentials document.
	awsEC2RolePath = "latest/meta-data/iam/security-credentials/"
	// awsEC2TokenPath is where a metadata session token is requested (via PUT).
	awsEC2TokenPath = "latest/api/token"
)

// defaultHTTPTimeout bounds each individual request to the metadata endpoint.
const defaultHTTPTimeout = 10 * time.Second
31
32
33 type EC2Provider struct {
34 httpClient *http.Client
35 expiration time.Time
36
37
38
39
40
41
42 expiryWindow time.Duration
43 }
44
45
46 func NewEC2Provider(httpClient *http.Client, expiryWindow time.Duration) *EC2Provider {
47 return &EC2Provider{
48 httpClient: httpClient,
49 expiryWindow: expiryWindow,
50 }
51 }
52
53 func (e *EC2Provider) getToken(ctx context.Context) (string, error) {
54 req, err := http.NewRequest(http.MethodPut, awsEC2URI+awsEC2TokenPath, nil)
55 if err != nil {
56 return "", err
57 }
58 const defaultEC2TTLSeconds = "30"
59 req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", defaultEC2TTLSeconds)
60
61 ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
62 defer cancel()
63 resp, err := e.httpClient.Do(req.WithContext(ctx))
64 if err != nil {
65 return "", err
66 }
67 defer resp.Body.Close()
68 if resp.StatusCode != http.StatusOK {
69 return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
70 }
71
72 token, err := ioutil.ReadAll(resp.Body)
73 if err != nil {
74 return "", err
75 }
76 if len(token) == 0 {
77 return "", errors.New("unable to retrieve token from EC2 metadata")
78 }
79 return string(token), nil
80 }
81
82 func (e *EC2Provider) getRoleName(ctx context.Context, token string) (string, error) {
83 req, err := http.NewRequest(http.MethodGet, awsEC2URI+awsEC2RolePath, nil)
84 if err != nil {
85 return "", err
86 }
87 req.Header.Set("X-aws-ec2-metadata-token", token)
88
89 ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
90 defer cancel()
91 resp, err := e.httpClient.Do(req.WithContext(ctx))
92 if err != nil {
93 return "", err
94 }
95 defer resp.Body.Close()
96 if resp.StatusCode != http.StatusOK {
97 return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
98 }
99
100 role, err := ioutil.ReadAll(resp.Body)
101 if err != nil {
102 return "", err
103 }
104 if len(role) == 0 {
105 return "", errors.New("unable to retrieve role_name from EC2 metadata")
106 }
107 return string(role), nil
108 }
109
110 func (e *EC2Provider) getCredentials(ctx context.Context, token string, role string) (credentials.Value, time.Time, error) {
111 v := credentials.Value{ProviderName: ec2ProviderName}
112
113 pathWithRole := awsEC2URI + awsEC2RolePath + role
114 req, err := http.NewRequest(http.MethodGet, pathWithRole, nil)
115 if err != nil {
116 return v, time.Time{}, err
117 }
118 req.Header.Set("X-aws-ec2-metadata-token", token)
119 ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
120 defer cancel()
121 resp, err := e.httpClient.Do(req.WithContext(ctx))
122 if err != nil {
123 return v, time.Time{}, err
124 }
125 defer resp.Body.Close()
126 if resp.StatusCode != http.StatusOK {
127 return v, time.Time{}, fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
128 }
129
130 var ec2Resp struct {
131 AccessKeyID string `json:"AccessKeyId"`
132 SecretAccessKey string `json:"SecretAccessKey"`
133 Token string `json:"Token"`
134 Expiration time.Time `json:"Expiration"`
135 }
136
137 err = json.NewDecoder(resp.Body).Decode(&ec2Resp)
138 if err != nil {
139 return v, time.Time{}, err
140 }
141
142 v.AccessKeyID = ec2Resp.AccessKeyID
143 v.SecretAccessKey = ec2Resp.SecretAccessKey
144 v.SessionToken = ec2Resp.Token
145
146 return v, ec2Resp.Expiration, nil
147 }
148
149
150 func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
151 v := credentials.Value{ProviderName: ec2ProviderName}
152
153 token, err := e.getToken(ctx)
154 if err != nil {
155 return v, err
156 }
157
158 role, err := e.getRoleName(ctx, token)
159 if err != nil {
160 return v, err
161 }
162
163 v, exp, err := e.getCredentials(ctx, token, role)
164 if err != nil {
165 return v, err
166 }
167 if !v.HasKeys() {
168 return v, errors.New("failed to retrieve EC2 keys")
169 }
170 e.expiration = exp.Add(-e.expiryWindow)
171
172 return v, nil
173 }
174
175
176 func (e *EC2Provider) Retrieve() (credentials.Value, error) {
177 return e.RetrieveWithContext(context.Background())
178 }
179
180
181 func (e *EC2Provider) IsExpired() bool {
182 return e.expiration.Before(time.Now())
183 }
184
View as plain text