...
1
2
3
4
5
6
7 package credproviders
8
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"time"

	"go.mongodb.org/mongo-driver/internal/aws/credentials"
	"go.mongodb.org/mongo-driver/internal/uuid"
)
21
// assumeRoleProviderName is stamped as the ProviderName of every credentials
// value produced by AssumeRoleProvider.
const assumeRoleProviderName = "AssumeRoleProvider"

// stsURI is the STS AssumeRoleWithWebIdentity endpoint template. Its three %s
// verbs are, in order, the RoleSessionName, RoleArn and WebIdentityToken query
// parameter values.
const stsURI = "https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15"
28
29
30 type AssumeRoleProvider struct {
31 AwsRoleArnEnv EnvVar
32 AwsWebIdentityTokenFileEnv EnvVar
33 AwsRoleSessionNameEnv EnvVar
34
35 httpClient *http.Client
36 expiration time.Time
37
38
39
40
41
42
43 expiryWindow time.Duration
44 }
45
46
47 func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider {
48 return &AssumeRoleProvider{
49
50 AwsRoleArnEnv: EnvVar("AWS_ROLE_ARN"),
51
52 AwsWebIdentityTokenFileEnv: EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE"),
53
54 AwsRoleSessionNameEnv: EnvVar("AWS_ROLE_SESSION_NAME"),
55 httpClient: httpClient,
56 expiryWindow: expiryWindow,
57 }
58 }
59
60
61 func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
62 const defaultHTTPTimeout = 10 * time.Second
63
64 v := credentials.Value{ProviderName: assumeRoleProviderName}
65
66 roleArn := a.AwsRoleArnEnv.Get()
67 tokenFile := a.AwsWebIdentityTokenFileEnv.Get()
68 if tokenFile == "" && roleArn == "" {
69 return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing")
70 }
71 if tokenFile != "" && roleArn == "" {
72 return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing")
73 }
74 if tokenFile == "" && roleArn != "" {
75 return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing")
76 }
77 token, err := ioutil.ReadFile(tokenFile)
78 if err != nil {
79 return v, err
80 }
81
82 sessionName := a.AwsRoleSessionNameEnv.Get()
83 if sessionName == "" {
84
85 id, err := uuid.New()
86 if err != nil {
87 return v, err
88 }
89 sessionName = id.String()
90 }
91
92 fullURI := fmt.Sprintf(stsURI, sessionName, roleArn, string(token))
93
94 req, err := http.NewRequest(http.MethodPost, fullURI, nil)
95 if err != nil {
96 return v, err
97 }
98 req.Header.Set("Accept", "application/json")
99
100 ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
101 defer cancel()
102 resp, err := a.httpClient.Do(req.WithContext(ctx))
103 if err != nil {
104 return v, err
105 }
106 defer resp.Body.Close()
107 if resp.StatusCode != http.StatusOK {
108 return v, fmt.Errorf("response failure: %s", resp.Status)
109 }
110
111 var stsResp struct {
112 Response struct {
113 Result struct {
114 Credentials struct {
115 AccessKeyID string `json:"AccessKeyId"`
116 SecretAccessKey string `json:"SecretAccessKey"`
117 Token string `json:"SessionToken"`
118 Expiration float64 `json:"Expiration"`
119 } `json:"Credentials"`
120 } `json:"AssumeRoleWithWebIdentityResult"`
121 } `json:"AssumeRoleWithWebIdentityResponse"`
122 }
123
124 err = json.NewDecoder(resp.Body).Decode(&stsResp)
125 if err != nil {
126 return v, err
127 }
128 v.AccessKeyID = stsResp.Response.Result.Credentials.AccessKeyID
129 v.SecretAccessKey = stsResp.Response.Result.Credentials.SecretAccessKey
130 v.SessionToken = stsResp.Response.Result.Credentials.Token
131 if !v.HasKeys() {
132 return v, errors.New("failed to retrieve web identity keys")
133 }
134 sec := int64(stsResp.Response.Result.Credentials.Expiration)
135 a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow)
136
137 return v, nil
138 }
139
140
141 func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
142 return a.RetrieveWithContext(context.Background())
143 }
144
145
146 func (a *AssumeRoleProvider) IsExpired() bool {
147 return a.expiration.Before(time.Now())
148 }
149
View as plain text