...

Source file src/go.mongodb.org/mongo-driver/internal/credproviders/imds_provider.go

Documentation: go.mongodb.org/mongo-driver/internal/credproviders

     1  // Copyright (C) MongoDB, Inc. 2023-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package credproviders
     8  
     9  import (
    10  	"context"
    11  	"encoding/json"
    12  	"fmt"
    13  	"io/ioutil"
    14  	"net/http"
    15  	"net/url"
    16  	"time"
    17  
    18  	"go.mongodb.org/mongo-driver/internal/aws/credentials"
    19  )
    20  
    21  const (
    22  	// AzureProviderName provides a name of Azure provider
    23  	AzureProviderName = "AzureProvider"
    24  
    25  	azureURI = "http://169.254.169.254/metadata/identity/oauth2/token"
    26  )
    27  
    28  // An AzureProvider retrieves credentials from Azure IMDS.
    29  type AzureProvider struct {
    30  	httpClient   *http.Client
    31  	expiration   time.Time
    32  	expiryWindow time.Duration
    33  }
    34  
    35  // NewAzureProvider returns a pointer to an Azure credential provider.
    36  func NewAzureProvider(httpClient *http.Client, expiryWindow time.Duration) *AzureProvider {
    37  	return &AzureProvider{
    38  		httpClient:   httpClient,
    39  		expiration:   time.Time{},
    40  		expiryWindow: expiryWindow,
    41  	}
    42  }
    43  
    44  // RetrieveWithContext retrieves the keys from the Azure service.
    45  func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
    46  	v := credentials.Value{ProviderName: AzureProviderName}
    47  	req, err := http.NewRequest(http.MethodGet, azureURI, nil)
    48  	if err != nil {
    49  		return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
    50  	}
    51  	q := make(url.Values)
    52  	q.Set("api-version", "2018-02-01")
    53  	q.Set("resource", "https://vault.azure.net")
    54  	req.URL.RawQuery = q.Encode()
    55  	req.Header.Set("Metadata", "true")
    56  	req.Header.Set("Accept", "application/json")
    57  
    58  	resp, err := a.httpClient.Do(req.WithContext(ctx))
    59  	if err != nil {
    60  		return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
    61  	}
    62  	defer resp.Body.Close()
    63  	body, err := ioutil.ReadAll(resp.Body)
    64  	if err != nil {
    65  		return v, fmt.Errorf("unable to retrieve Azure credentials: error reading response body: %w", err)
    66  	}
    67  	if resp.StatusCode != http.StatusOK {
    68  		return v, fmt.Errorf("unable to retrieve Azure credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body)
    69  	}
    70  	var tokenResponse struct {
    71  		AccessToken string `json:"access_token"`
    72  		ExpiresIn   string `json:"expires_in"`
    73  	}
    74  	// Attempt to read body as JSON
    75  	err = json.Unmarshal(body, &tokenResponse)
    76  	if err != nil {
    77  		return v, fmt.Errorf("unable to retrieve Azure credentials: error reading body JSON: %w (response body: %s)", err, body)
    78  	}
    79  	if tokenResponse.AccessToken == "" {
    80  		return v, fmt.Errorf("unable to retrieve Azure credentials: got unexpected empty accessToken from Azure Metadata Server. Response body: %s", body)
    81  	}
    82  	v.SessionToken = tokenResponse.AccessToken
    83  
    84  	expiresIn, err := time.ParseDuration(tokenResponse.ExpiresIn + "s")
    85  	if err != nil {
    86  		return v, err
    87  	}
    88  	if expiration := expiresIn - a.expiryWindow; expiration > 0 {
    89  		a.expiration = time.Now().Add(expiration)
    90  	}
    91  
    92  	return v, err
    93  }
    94  
    95  // Retrieve retrieves the keys from the Azure service.
    96  func (a *AzureProvider) Retrieve() (credentials.Value, error) {
    97  	return a.RetrieveWithContext(context.Background())
    98  }
    99  
   100  // IsExpired returns if the credentials have been retrieved.
   101  func (a *AzureProvider) IsExpired() bool {
   102  	return a.expiration.Before(time.Now())
   103  }
   104  

View as plain text