...
1
2
3
4
5
6
7 package credproviders
8
9 import (
10 "context"
11 "encoding/json"
12 "fmt"
13 "io/ioutil"
14 "net/http"
15 "net/url"
16 "time"
17
18 "go.mongodb.org/mongo-driver/internal/aws/credentials"
19 )
20
const (
	// azureURI is the metadata-service token endpoint that AzureProvider
	// queries. The host 169.254.169.254 is a link-local address, so the
	// endpoint is only reachable from inside the hosting VM/instance.
	azureURI = "http://169.254.169.254/metadata/identity/oauth2/token"

	// AzureProviderName is the ProviderName stamped on every
	// credentials.Value that AzureProvider returns.
	AzureProviderName = "AzureProvider"
)
27
28
29 type AzureProvider struct {
30 httpClient *http.Client
31 expiration time.Time
32 expiryWindow time.Duration
33 }
34
// NewAzureProvider returns an AzureProvider that sends its token requests
// through httpClient and treats each token as expired expiryWindow before its
// reported lifetime ends.
//
// The expiration is left at the zero time.Time, so IsExpired reports true
// until RetrieveWithContext records a lifetime.
func NewAzureProvider(httpClient *http.Client, expiryWindow time.Duration) *AzureProvider {
	provider := new(AzureProvider)
	provider.httpClient = httpClient
	provider.expiryWindow = expiryWindow
	return provider
}
43
// RetrieveWithContext requests an access token for the https://vault.azure.net
// resource (api-version 2018-02-01) from the Azure metadata token endpoint and
// returns it as the SessionToken of a credentials.Value.
//
// ctx bounds the HTTP request. If the provider holds a nil *http.Client,
// http.DefaultClient is used instead of panicking.
//
// On success the provider records when the token should be considered
// expired: its reported lifetime (expires_in, in seconds) minus the
// configured expiry window. If the window is at least as long as the
// lifetime, the recorded expiration is reset to the zero time. IsExpired then
// reports true instead of trusting a value left over from an earlier
// retrieval.
//
// On any error the returned Value carries only ProviderName.
func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	v := credentials.Value{ProviderName: AzureProviderName}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, azureURI, nil)
	if err != nil {
		return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
	}
	q := make(url.Values)
	q.Set("api-version", "2018-02-01")
	q.Set("resource", "https://vault.azure.net")
	req.URL.RawQuery = q.Encode()
	req.Header.Set("Metadata", "true")
	req.Header.Set("Accept", "application/json")

	client := a.httpClient
	if client == nil {
		client = http.DefaultClient
	}
	resp, err := client.Do(req)
	if err != nil {
		return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return v, fmt.Errorf("unable to retrieve Azure credentials: error reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return v, fmt.Errorf("unable to retrieve Azure credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body)
	}

	// expires_in is decoded as a string and then parsed as a count of seconds.
	var tokenResponse struct {
		AccessToken string `json:"access_token"`
		ExpiresIn   string `json:"expires_in"`
	}
	if err := json.Unmarshal(body, &tokenResponse); err != nil {
		return v, fmt.Errorf("unable to retrieve Azure credentials: error reading body JSON: %w (response body: %s)", err, body)
	}
	if tokenResponse.AccessToken == "" {
		return v, fmt.Errorf("unable to retrieve Azure credentials: got unexpected empty accessToken from Azure Metadata Server. Response body: %s", body)
	}

	expiresIn, err := time.ParseDuration(tokenResponse.ExpiresIn + "s")
	if err != nil {
		return v, fmt.Errorf("unable to retrieve Azure credentials: error parsing expires_in %q: %w", tokenResponse.ExpiresIn, err)
	}
	if lifetime := expiresIn - a.expiryWindow; lifetime > 0 {
		a.expiration = time.Now().Add(lifetime)
	} else {
		// The token is already inside the expiry window. Clear any expiration
		// recorded by an earlier call so IsExpired does not vouch for it.
		a.expiration = time.Time{}
	}

	// Populate the token only once every step has succeeded, so error paths
	// never return a partially filled Value.
	v.SessionToken = tokenResponse.AccessToken
	return v, nil
}
94
// Retrieve fetches a token exactly as RetrieveWithContext does, using a
// background context that is never cancelled.
func (a *AzureProvider) Retrieve() (credentials.Value, error) {
	background := context.Background()
	return a.RetrieveWithContext(background)
}
99
// IsExpired reports whether the recorded expiration instant is strictly in
// the past. A provider that has never recorded one (zero time) is always
// expired.
func (a *AzureProvider) IsExpired() bool {
	now := time.Now()
	return now.After(a.expiration)
}
104
View as plain text