...

Source file src/github.com/Azure/go-autorest/autorest/adal/token.go

Documentation: github.com/Azure/go-autorest/autorest/adal

     1  package adal
     2  
     3  // Copyright 2017 Microsoft Corporation
     4  //
     5  //  Licensed under the Apache License, Version 2.0 (the "License");
     6  //  you may not use this file except in compliance with the License.
     7  //  You may obtain a copy of the License at
     8  //
     9  //      http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  //  Unless required by applicable law or agreed to in writing, software
    12  //  distributed under the License is distributed on an "AS IS" BASIS,
    13  //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  //  See the License for the specific language governing permissions and
    15  //  limitations under the License.
    16  
    17  import (
    18  	"context"
    19  	"crypto/rand"
    20  	"crypto/rsa"
    21  	"crypto/sha1"
    22  	"crypto/x509"
    23  	"encoding/base64"
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"io/ioutil"
    29  	"math"
    30  	"net/http"
    31  	"net/url"
    32  	"os"
    33  	"strconv"
    34  	"strings"
    35  	"sync"
    36  	"time"
    37  
    38  	"github.com/Azure/go-autorest/autorest/date"
    39  	"github.com/Azure/go-autorest/logger"
    40  	"github.com/golang-jwt/jwt/v4"
    41  )
    42  
    43  const (
    44  	defaultRefresh = 5 * time.Minute
    45  
    46  	// OAuthGrantTypeDeviceCode is the "grant_type" identifier used in device flow
    47  	OAuthGrantTypeDeviceCode = "device_code"
    48  
    49  	// OAuthGrantTypeClientCredentials is the "grant_type" identifier used in credential flows
    50  	OAuthGrantTypeClientCredentials = "client_credentials"
    51  
    52  	// OAuthGrantTypeUserPass is the "grant_type" identifier used in username and password auth flows
    53  	OAuthGrantTypeUserPass = "password"
    54  
    55  	// OAuthGrantTypeRefreshToken is the "grant_type" identifier used in refresh token flows
    56  	OAuthGrantTypeRefreshToken = "refresh_token"
    57  
    58  	// OAuthGrantTypeAuthorizationCode is the "grant_type" identifier used in authorization code flows
    59  	OAuthGrantTypeAuthorizationCode = "authorization_code"
    60  
    61  	// metadataHeader is the header required by MSI extension
    62  	metadataHeader = "Metadata"
    63  
    64  	// msiEndpoint is the well known endpoint for getting MSI authentications tokens
    65  	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
    66  
    67  	// the API version to use for the MSI endpoint
    68  	msiAPIVersion = "2018-02-01"
    69  
    70  	// the default number of attempts to refresh an MSI authentication token
    71  	defaultMaxMSIRefreshAttempts = 5
    72  
    73  	// asMSIEndpointEnv is the environment variable used to store the endpoint on App Service and Functions
    74  	msiEndpointEnv = "MSI_ENDPOINT"
    75  
    76  	// asMSISecretEnv is the environment variable used to store the request secret on App Service and Functions
    77  	msiSecretEnv = "MSI_SECRET"
    78  
    79  	// the API version to use for the legacy App Service MSI endpoint
    80  	appServiceAPIVersion2017 = "2017-09-01"
    81  
    82  	// secret header used when authenticating against app service MSI endpoint
    83  	secretHeader = "Secret"
    84  
    85  	// the format for expires_on in UTC with AM/PM
    86  	expiresOnDateFormatPM = "1/2/2006 15:04:05 PM +00:00"
    87  
    88  	// the format for expires_on in UTC without AM/PM
    89  	expiresOnDateFormat = "1/2/2006 15:04:05 +00:00"
    90  )
    91  
    92  // OAuthTokenProvider is an interface which should be implemented by an access token retriever
    93  type OAuthTokenProvider interface {
    94  	OAuthToken() string
    95  }
    96  
    97  // MultitenantOAuthTokenProvider provides tokens used for multi-tenant authorization.
    98  type MultitenantOAuthTokenProvider interface {
    99  	PrimaryOAuthToken() string
   100  	AuxiliaryOAuthTokens() []string
   101  }
   102  
   103  // TokenRefreshError is an interface used by errors returned during token refresh.
   104  type TokenRefreshError interface {
   105  	error
   106  	Response() *http.Response
   107  }
   108  
   109  // Refresher is an interface for token refresh functionality
   110  type Refresher interface {
   111  	Refresh() error
   112  	RefreshExchange(resource string) error
   113  	EnsureFresh() error
   114  }
   115  
   116  // RefresherWithContext is an interface for token refresh functionality
   117  type RefresherWithContext interface {
   118  	RefreshWithContext(ctx context.Context) error
   119  	RefreshExchangeWithContext(ctx context.Context, resource string) error
   120  	EnsureFreshWithContext(ctx context.Context) error
   121  }
   122  
   123  // TokenRefreshCallback is the type representing callbacks that will be called after
   124  // a successful token refresh
   125  type TokenRefreshCallback func(Token) error
   126  
   127  // TokenRefresh is a type representing a custom callback to refresh a token
   128  type TokenRefresh func(ctx context.Context, resource string) (*Token, error)
   129  
   130  // JWTCallback is the type representing callback that will be called to get the federated OIDC JWT
   131  type JWTCallback func() (string, error)
   132  
   133  // Token encapsulates the access token used to authorize Azure requests.
   134  // https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
   135  type Token struct {
   136  	AccessToken  string `json:"access_token"`
   137  	RefreshToken string `json:"refresh_token"`
   138  
   139  	ExpiresIn json.Number `json:"expires_in"`
   140  	ExpiresOn json.Number `json:"expires_on"`
   141  	NotBefore json.Number `json:"not_before"`
   142  
   143  	Resource string `json:"resource"`
   144  	Type     string `json:"token_type"`
   145  }
   146  
   147  func newToken() Token {
   148  	return Token{
   149  		ExpiresIn: "0",
   150  		ExpiresOn: "0",
   151  		NotBefore: "0",
   152  	}
   153  }
   154  
   155  // IsZero returns true if the token object is zero-initialized.
   156  func (t Token) IsZero() bool {
   157  	return t == Token{}
   158  }
   159  
   160  // Expires returns the time.Time when the Token expires.
   161  func (t Token) Expires() time.Time {
   162  	s, err := t.ExpiresOn.Float64()
   163  	if err != nil {
   164  		s = -3600
   165  	}
   166  
   167  	expiration := date.NewUnixTimeFromSeconds(s)
   168  
   169  	return time.Time(expiration).UTC()
   170  }
   171  
   172  // IsExpired returns true if the Token is expired, false otherwise.
   173  func (t Token) IsExpired() bool {
   174  	return t.WillExpireIn(0)
   175  }
   176  
   177  // WillExpireIn returns true if the Token will expire after the passed time.Duration interval
   178  // from now, false otherwise.
   179  func (t Token) WillExpireIn(d time.Duration) bool {
   180  	return !t.Expires().After(time.Now().Add(d))
   181  }
   182  
   183  // OAuthToken return the current access token
   184  func (t *Token) OAuthToken() string {
   185  	return t.AccessToken
   186  }
   187  
   188  // ServicePrincipalSecret is an interface that allows various secret mechanism to fill the form
   189  // that is submitted when acquiring an oAuth token.
   190  type ServicePrincipalSecret interface {
   191  	SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
   192  }
   193  
   194  // ServicePrincipalNoSecret represents a secret type that contains no secret
   195  // meaning it is not valid for fetching a fresh token. This is used by Manual
   196  type ServicePrincipalNoSecret struct {
   197  }
   198  
   199  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret
   200  // It only returns an error for the ServicePrincipalNoSecret type
   201  func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
   202  	return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
   203  }
   204  
   205  // MarshalJSON implements the json.Marshaler interface.
   206  func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
   207  	type tokenType struct {
   208  		Type string `json:"type"`
   209  	}
   210  	return json.Marshal(tokenType{
   211  		Type: "ServicePrincipalNoSecret",
   212  	})
   213  }
   214  
   215  // ServicePrincipalTokenSecret implements ServicePrincipalSecret for client_secret type authorization.
   216  type ServicePrincipalTokenSecret struct {
   217  	ClientSecret string `json:"value"`
   218  }
   219  
   220  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
   221  // It will populate the form submitted during oAuth Token Acquisition using the client_secret.
   222  func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
   223  	v.Set("client_secret", tokenSecret.ClientSecret)
   224  	return nil
   225  }
   226  
   227  // MarshalJSON implements the json.Marshaler interface.
   228  func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
   229  	type tokenType struct {
   230  		Type  string `json:"type"`
   231  		Value string `json:"value"`
   232  	}
   233  	return json.Marshal(tokenType{
   234  		Type:  "ServicePrincipalTokenSecret",
   235  		Value: tokenSecret.ClientSecret,
   236  	})
   237  }
   238  
   239  // ServicePrincipalCertificateSecret implements ServicePrincipalSecret for generic RSA cert auth with signed JWTs.
   240  type ServicePrincipalCertificateSecret struct {
   241  	Certificate *x509.Certificate
   242  	PrivateKey  *rsa.PrivateKey
   243  }
   244  
   245  // SignJwt returns the JWT signed with the certificate's private key.
   246  func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
   247  	hasher := sha1.New()
   248  	_, err := hasher.Write(secret.Certificate.Raw)
   249  	if err != nil {
   250  		return "", err
   251  	}
   252  
   253  	thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
   254  
   255  	// The jti (JWT ID) claim provides a unique identifier for the JWT.
   256  	jti := make([]byte, 20)
   257  	_, err = rand.Read(jti)
   258  	if err != nil {
   259  		return "", err
   260  	}
   261  
   262  	token := jwt.New(jwt.SigningMethodRS256)
   263  	token.Header["x5t"] = thumbprint
   264  	x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
   265  	token.Header["x5c"] = x5c
   266  	token.Claims = jwt.MapClaims{
   267  		"aud": spt.inner.OauthConfig.TokenEndpoint.String(),
   268  		"iss": spt.inner.ClientID,
   269  		"sub": spt.inner.ClientID,
   270  		"jti": base64.URLEncoding.EncodeToString(jti),
   271  		"nbf": time.Now().Unix(),
   272  		"exp": time.Now().Add(24 * time.Hour).Unix(),
   273  	}
   274  
   275  	signedString, err := token.SignedString(secret.PrivateKey)
   276  	return signedString, err
   277  }
   278  
   279  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
   280  // It will populate the form submitted during oAuth Token Acquisition using a JWT signed with a certificate.
   281  func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
   282  	jwt, err := secret.SignJwt(spt)
   283  	if err != nil {
   284  		return err
   285  	}
   286  
   287  	v.Set("client_assertion", jwt)
   288  	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
   289  	return nil
   290  }
   291  
   292  // MarshalJSON implements the json.Marshaler interface.
   293  func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
   294  	return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
   295  }
   296  
   297  // ServicePrincipalMSISecret implements ServicePrincipalSecret for machines running the MSI Extension.
   298  type ServicePrincipalMSISecret struct {
   299  	msiType          msiType
   300  	clientResourceID string
   301  }
   302  
   303  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
   304  func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
   305  	return nil
   306  }
   307  
   308  // MarshalJSON implements the json.Marshaler interface.
   309  func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
   310  	return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
   311  }
   312  
   313  // ServicePrincipalUsernamePasswordSecret implements ServicePrincipalSecret for username and password auth.
   314  type ServicePrincipalUsernamePasswordSecret struct {
   315  	Username string `json:"username"`
   316  	Password string `json:"password"`
   317  }
   318  
   319  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
   320  func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
   321  	v.Set("username", secret.Username)
   322  	v.Set("password", secret.Password)
   323  	return nil
   324  }
   325  
   326  // MarshalJSON implements the json.Marshaler interface.
   327  func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
   328  	type tokenType struct {
   329  		Type     string `json:"type"`
   330  		Username string `json:"username"`
   331  		Password string `json:"password"`
   332  	}
   333  	return json.Marshal(tokenType{
   334  		Type:     "ServicePrincipalUsernamePasswordSecret",
   335  		Username: secret.Username,
   336  		Password: secret.Password,
   337  	})
   338  }
   339  
   340  // ServicePrincipalAuthorizationCodeSecret implements ServicePrincipalSecret for authorization code auth.
   341  type ServicePrincipalAuthorizationCodeSecret struct {
   342  	ClientSecret      string `json:"value"`
   343  	AuthorizationCode string `json:"authCode"`
   344  	RedirectURI       string `json:"redirect"`
   345  }
   346  
   347  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
   348  func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
   349  	v.Set("code", secret.AuthorizationCode)
   350  	v.Set("client_secret", secret.ClientSecret)
   351  	v.Set("redirect_uri", secret.RedirectURI)
   352  	return nil
   353  }
   354  
   355  // MarshalJSON implements the json.Marshaler interface.
   356  func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
   357  	type tokenType struct {
   358  		Type     string `json:"type"`
   359  		Value    string `json:"value"`
   360  		AuthCode string `json:"authCode"`
   361  		Redirect string `json:"redirect"`
   362  	}
   363  	return json.Marshal(tokenType{
   364  		Type:     "ServicePrincipalAuthorizationCodeSecret",
   365  		Value:    secret.ClientSecret,
   366  		AuthCode: secret.AuthorizationCode,
   367  		Redirect: secret.RedirectURI,
   368  	})
   369  }
   370  
   371  // ServicePrincipalFederatedSecret implements ServicePrincipalSecret for Federated JWTs.
   372  type ServicePrincipalFederatedSecret struct {
   373  	jwtCallback JWTCallback
   374  }
   375  
   376  // SetAuthenticationValues is a method of the interface ServicePrincipalSecret.
   377  // It will populate the form submitted during OAuth Token Acquisition using a JWT signed by an OIDC issuer.
   378  func (secret *ServicePrincipalFederatedSecret) SetAuthenticationValues(_ *ServicePrincipalToken, v *url.Values) error {
   379  	jwt, err := secret.jwtCallback()
   380  	if err != nil {
   381  		return err
   382  	}
   383  
   384  	v.Set("client_assertion", jwt)
   385  	v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
   386  	return nil
   387  }
   388  
   389  // MarshalJSON implements the json.Marshaler interface.
   390  func (secret ServicePrincipalFederatedSecret) MarshalJSON() ([]byte, error) {
   391  	return nil, errors.New("marshalling ServicePrincipalFederatedSecret is not supported")
   392  }
   393  
   394  // ServicePrincipalToken encapsulates a Token created for a Service Principal.
   395  type ServicePrincipalToken struct {
   396  	inner             servicePrincipalToken
   397  	refreshLock       *sync.RWMutex
   398  	sender            Sender
   399  	customRefreshFunc TokenRefresh
   400  	refreshCallbacks  []TokenRefreshCallback
   401  	// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
   402  	// Settings this to a value less than 1 will use the default value.
   403  	MaxMSIRefreshAttempts int
   404  }
   405  
   406  // MarshalTokenJSON returns the marshalled inner token.
   407  func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
   408  	return json.Marshal(spt.inner.Token)
   409  }
   410  
   411  // SetRefreshCallbacks replaces any existing refresh callbacks with the specified callbacks.
   412  func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
   413  	spt.refreshCallbacks = callbacks
   414  }
   415  
   416  // SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
   417  func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
   418  	spt.customRefreshFunc = customRefreshFunc
   419  }
   420  
   421  // MarshalJSON implements the json.Marshaler interface.
   422  func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
   423  	return json.Marshal(spt.inner)
   424  }
   425  
   426  // UnmarshalJSON implements the json.Unmarshaler interface.
   427  func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
   428  	// need to determine the token type
   429  	raw := map[string]interface{}{}
   430  	err := json.Unmarshal(data, &raw)
   431  	if err != nil {
   432  		return err
   433  	}
   434  	secret := raw["secret"].(map[string]interface{})
   435  	switch secret["type"] {
   436  	case "ServicePrincipalNoSecret":
   437  		spt.inner.Secret = &ServicePrincipalNoSecret{}
   438  	case "ServicePrincipalTokenSecret":
   439  		spt.inner.Secret = &ServicePrincipalTokenSecret{}
   440  	case "ServicePrincipalCertificateSecret":
   441  		return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
   442  	case "ServicePrincipalMSISecret":
   443  		return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
   444  	case "ServicePrincipalUsernamePasswordSecret":
   445  		spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
   446  	case "ServicePrincipalAuthorizationCodeSecret":
   447  		spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
   448  	case "ServicePrincipalFederatedSecret":
   449  		return errors.New("unmarshalling ServicePrincipalFederatedSecret is not supported")
   450  	default:
   451  		return fmt.Errorf("unrecognized token type '%s'", secret["type"])
   452  	}
   453  	err = json.Unmarshal(data, &spt.inner)
   454  	if err != nil {
   455  		return err
   456  	}
   457  	// Don't override the refreshLock or the sender if those have been already set.
   458  	if spt.refreshLock == nil {
   459  		spt.refreshLock = &sync.RWMutex{}
   460  	}
   461  	if spt.sender == nil {
   462  		spt.sender = sender()
   463  	}
   464  	return nil
   465  }
   466  
   467  // internal type used for marshalling/unmarshalling
   468  type servicePrincipalToken struct {
   469  	Token         Token                  `json:"token"`
   470  	Secret        ServicePrincipalSecret `json:"secret"`
   471  	OauthConfig   OAuthConfig            `json:"oauth"`
   472  	ClientID      string                 `json:"clientID"`
   473  	Resource      string                 `json:"resource"`
   474  	AutoRefresh   bool                   `json:"autoRefresh"`
   475  	RefreshWithin time.Duration          `json:"refreshWithin"`
   476  }
   477  
   478  func validateOAuthConfig(oac OAuthConfig) error {
   479  	if oac.IsZero() {
   480  		return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
   481  	}
   482  	return nil
   483  }
   484  
   485  // NewServicePrincipalTokenWithSecret create a ServicePrincipalToken using the supplied ServicePrincipalSecret implementation.
   486  func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   487  	if err := validateOAuthConfig(oauthConfig); err != nil {
   488  		return nil, err
   489  	}
   490  	if err := validateStringParam(id, "id"); err != nil {
   491  		return nil, err
   492  	}
   493  	if err := validateStringParam(resource, "resource"); err != nil {
   494  		return nil, err
   495  	}
   496  	if secret == nil {
   497  		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
   498  	}
   499  	spt := &ServicePrincipalToken{
   500  		inner: servicePrincipalToken{
   501  			Token:         newToken(),
   502  			OauthConfig:   oauthConfig,
   503  			Secret:        secret,
   504  			ClientID:      id,
   505  			Resource:      resource,
   506  			AutoRefresh:   true,
   507  			RefreshWithin: defaultRefresh,
   508  		},
   509  		refreshLock:      &sync.RWMutex{},
   510  		sender:           sender(),
   511  		refreshCallbacks: callbacks,
   512  	}
   513  	return spt, nil
   514  }
   515  
   516  // NewServicePrincipalTokenFromManualToken creates a ServicePrincipalToken using the supplied token
   517  func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   518  	if err := validateOAuthConfig(oauthConfig); err != nil {
   519  		return nil, err
   520  	}
   521  	if err := validateStringParam(clientID, "clientID"); err != nil {
   522  		return nil, err
   523  	}
   524  	if err := validateStringParam(resource, "resource"); err != nil {
   525  		return nil, err
   526  	}
   527  	if token.IsZero() {
   528  		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
   529  	}
   530  	spt, err := NewServicePrincipalTokenWithSecret(
   531  		oauthConfig,
   532  		clientID,
   533  		resource,
   534  		&ServicePrincipalNoSecret{},
   535  		callbacks...)
   536  	if err != nil {
   537  		return nil, err
   538  	}
   539  
   540  	spt.inner.Token = token
   541  
   542  	return spt, nil
   543  }
   544  
   545  // NewServicePrincipalTokenFromManualTokenSecret creates a ServicePrincipalToken using the supplied token and secret
   546  func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   547  	if err := validateOAuthConfig(oauthConfig); err != nil {
   548  		return nil, err
   549  	}
   550  	if err := validateStringParam(clientID, "clientID"); err != nil {
   551  		return nil, err
   552  	}
   553  	if err := validateStringParam(resource, "resource"); err != nil {
   554  		return nil, err
   555  	}
   556  	if secret == nil {
   557  		return nil, fmt.Errorf("parameter 'secret' cannot be nil")
   558  	}
   559  	if token.IsZero() {
   560  		return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
   561  	}
   562  	spt, err := NewServicePrincipalTokenWithSecret(
   563  		oauthConfig,
   564  		clientID,
   565  		resource,
   566  		secret,
   567  		callbacks...)
   568  	if err != nil {
   569  		return nil, err
   570  	}
   571  
   572  	spt.inner.Token = token
   573  
   574  	return spt, nil
   575  }
   576  
   577  // NewServicePrincipalToken creates a ServicePrincipalToken from the supplied Service Principal
   578  // credentials scoped to the named resource.
   579  func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   580  	if err := validateOAuthConfig(oauthConfig); err != nil {
   581  		return nil, err
   582  	}
   583  	if err := validateStringParam(clientID, "clientID"); err != nil {
   584  		return nil, err
   585  	}
   586  	if err := validateStringParam(secret, "secret"); err != nil {
   587  		return nil, err
   588  	}
   589  	if err := validateStringParam(resource, "resource"); err != nil {
   590  		return nil, err
   591  	}
   592  	return NewServicePrincipalTokenWithSecret(
   593  		oauthConfig,
   594  		clientID,
   595  		resource,
   596  		&ServicePrincipalTokenSecret{
   597  			ClientSecret: secret,
   598  		},
   599  		callbacks...,
   600  	)
   601  }
   602  
   603  // NewServicePrincipalTokenFromCertificate creates a ServicePrincipalToken from the supplied pkcs12 bytes.
   604  func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   605  	if err := validateOAuthConfig(oauthConfig); err != nil {
   606  		return nil, err
   607  	}
   608  	if err := validateStringParam(clientID, "clientID"); err != nil {
   609  		return nil, err
   610  	}
   611  	if err := validateStringParam(resource, "resource"); err != nil {
   612  		return nil, err
   613  	}
   614  	if certificate == nil {
   615  		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
   616  	}
   617  	if privateKey == nil {
   618  		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
   619  	}
   620  	return NewServicePrincipalTokenWithSecret(
   621  		oauthConfig,
   622  		clientID,
   623  		resource,
   624  		&ServicePrincipalCertificateSecret{
   625  			PrivateKey:  privateKey,
   626  			Certificate: certificate,
   627  		},
   628  		callbacks...,
   629  	)
   630  }
   631  
   632  // NewServicePrincipalTokenFromUsernamePassword creates a ServicePrincipalToken from the username and password.
   633  func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   634  	if err := validateOAuthConfig(oauthConfig); err != nil {
   635  		return nil, err
   636  	}
   637  	if err := validateStringParam(clientID, "clientID"); err != nil {
   638  		return nil, err
   639  	}
   640  	if err := validateStringParam(username, "username"); err != nil {
   641  		return nil, err
   642  	}
   643  	if err := validateStringParam(password, "password"); err != nil {
   644  		return nil, err
   645  	}
   646  	if err := validateStringParam(resource, "resource"); err != nil {
   647  		return nil, err
   648  	}
   649  	return NewServicePrincipalTokenWithSecret(
   650  		oauthConfig,
   651  		clientID,
   652  		resource,
   653  		&ServicePrincipalUsernamePasswordSecret{
   654  			Username: username,
   655  			Password: password,
   656  		},
   657  		callbacks...,
   658  	)
   659  }
   660  
   661  // NewServicePrincipalTokenFromAuthorizationCode creates a ServicePrincipalToken from the
   662  func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   663  
   664  	if err := validateOAuthConfig(oauthConfig); err != nil {
   665  		return nil, err
   666  	}
   667  	if err := validateStringParam(clientID, "clientID"); err != nil {
   668  		return nil, err
   669  	}
   670  	if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
   671  		return nil, err
   672  	}
   673  	if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
   674  		return nil, err
   675  	}
   676  	if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
   677  		return nil, err
   678  	}
   679  	if err := validateStringParam(resource, "resource"); err != nil {
   680  		return nil, err
   681  	}
   682  
   683  	return NewServicePrincipalTokenWithSecret(
   684  		oauthConfig,
   685  		clientID,
   686  		resource,
   687  		&ServicePrincipalAuthorizationCodeSecret{
   688  			ClientSecret:      clientSecret,
   689  			AuthorizationCode: authorizationCode,
   690  			RedirectURI:       redirectURI,
   691  		},
   692  		callbacks...,
   693  	)
   694  }
   695  
   696  // NewServicePrincipalTokenFromFederatedToken creates a ServicePrincipalToken from the supplied federated OIDC JWT.
   697  //
   698  // Deprecated: Use NewServicePrincipalTokenFromFederatedTokenWithCallback to refresh jwt dynamically.
   699  func NewServicePrincipalTokenFromFederatedToken(oauthConfig OAuthConfig, clientID string, jwt string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   700  	if err := validateOAuthConfig(oauthConfig); err != nil {
   701  		return nil, err
   702  	}
   703  	if err := validateStringParam(clientID, "clientID"); err != nil {
   704  		return nil, err
   705  	}
   706  	if err := validateStringParam(resource, "resource"); err != nil {
   707  		return nil, err
   708  	}
   709  	if jwt == "" {
   710  		return nil, fmt.Errorf("parameter 'jwt' cannot be empty")
   711  	}
   712  	return NewServicePrincipalTokenFromFederatedTokenCallback(
   713  		oauthConfig,
   714  		clientID,
   715  		func() (string, error) {
   716  			return jwt, nil
   717  		},
   718  		resource,
   719  		callbacks...,
   720  	)
   721  }
   722  
   723  // NewServicePrincipalTokenFromFederatedTokenCallback creates a ServicePrincipalToken from the supplied federated OIDC JWTCallback.
   724  func NewServicePrincipalTokenFromFederatedTokenCallback(oauthConfig OAuthConfig, clientID string, jwtCallback JWTCallback, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   725  	if err := validateOAuthConfig(oauthConfig); err != nil {
   726  		return nil, err
   727  	}
   728  	if err := validateStringParam(clientID, "clientID"); err != nil {
   729  		return nil, err
   730  	}
   731  	if err := validateStringParam(resource, "resource"); err != nil {
   732  		return nil, err
   733  	}
   734  	if jwtCallback == nil {
   735  		return nil, fmt.Errorf("parameter 'jwtCallback' cannot be empty")
   736  	}
   737  	return NewServicePrincipalTokenWithSecret(
   738  		oauthConfig,
   739  		clientID,
   740  		resource,
   741  		&ServicePrincipalFederatedSecret{
   742  			jwtCallback: jwtCallback,
   743  		},
   744  		callbacks...,
   745  	)
   746  }
   747  
   748  type msiType int
   749  
   750  const (
   751  	msiTypeUnavailable msiType = iota
   752  	msiTypeAppServiceV20170901
   753  	msiTypeCloudShell
   754  	msiTypeIMDS
   755  )
   756  
   757  func (m msiType) String() string {
   758  	switch m {
   759  	case msiTypeAppServiceV20170901:
   760  		return "AppServiceV20170901"
   761  	case msiTypeCloudShell:
   762  		return "CloudShell"
   763  	case msiTypeIMDS:
   764  		return "IMDS"
   765  	default:
   766  		return fmt.Sprintf("unhandled MSI type %d", m)
   767  	}
   768  }
   769  
   770  // returns the MSI type and endpoint, or an error
   771  func getMSIType() (msiType, string, error) {
   772  	if endpointEnvVar := os.Getenv(msiEndpointEnv); endpointEnvVar != "" {
   773  		// if the env var MSI_ENDPOINT is set
   774  		if secretEnvVar := os.Getenv(msiSecretEnv); secretEnvVar != "" {
   775  			// if BOTH the env vars MSI_ENDPOINT and MSI_SECRET are set the msiType is AppService
   776  			return msiTypeAppServiceV20170901, endpointEnvVar, nil
   777  		}
   778  		// if ONLY the env var MSI_ENDPOINT is set the msiType is CloudShell
   779  		return msiTypeCloudShell, endpointEnvVar, nil
   780  	}
   781  	// if MSI_ENDPOINT is NOT set assume the msiType is IMDS
   782  	return msiTypeIMDS, msiEndpoint, nil
   783  }
   784  
   785  // GetMSIVMEndpoint gets the MSI endpoint on Virtual Machines.
   786  // NOTE: this always returns the IMDS endpoint, it does not work for app services or cloud shell.
   787  // Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
   788  func GetMSIVMEndpoint() (string, error) {
   789  	return msiEndpoint, nil
   790  }
   791  
   792  // GetMSIAppServiceEndpoint get the MSI endpoint for App Service and Functions.
   793  // It will return an error when not running in an app service/functions environment.
   794  // Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
   795  func GetMSIAppServiceEndpoint() (string, error) {
   796  	msiType, endpoint, err := getMSIType()
   797  	if err != nil {
   798  		return "", err
   799  	}
   800  	switch msiType {
   801  	case msiTypeAppServiceV20170901:
   802  		return endpoint, nil
   803  	default:
   804  		return "", fmt.Errorf("%s is not app service environment", msiType)
   805  	}
   806  }
   807  
   808  // GetMSIEndpoint get the appropriate MSI endpoint depending on the runtime environment
   809  // Deprecated: NewServicePrincipalTokenFromMSI() and variants will automatically detect the endpoint.
   810  func GetMSIEndpoint() (string, error) {
   811  	_, endpoint, err := getMSIType()
   812  	return endpoint, err
   813  }
   814  
   815  // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
   816  // It will use the system assigned identity when creating the token.
   817  // msiEndpoint - empty string, or pass a non-empty string to override the default value.
   818  // Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
   819  func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   820  	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", "", callbacks...)
   821  }
   822  
   823  // NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
   824  // It will use the clientID of specified user assigned identity when creating the token.
   825  // msiEndpoint - empty string, or pass a non-empty string to override the default value.
   826  // Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
   827  func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   828  	if err := validateStringParam(userAssignedID, "userAssignedID"); err != nil {
   829  		return nil, err
   830  	}
   831  	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, "", callbacks...)
   832  }
   833  
   834  // NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension.
   835  // It will use the azure resource id of user assigned identity when creating the token.
   836  // msiEndpoint - empty string, or pass a non-empty string to override the default value.
   837  // Deprecated: use NewServicePrincipalTokenFromManagedIdentity() instead.
   838  func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   839  	if err := validateStringParam(identityResourceID, "identityResourceID"); err != nil {
   840  		return nil, err
   841  	}
   842  	return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", identityResourceID, callbacks...)
   843  }
   844  
   845  // ManagedIdentityOptions contains optional values for configuring managed identity authentication.
   846  type ManagedIdentityOptions struct {
   847  	// ClientID is the user-assigned identity to use during authentication.
   848  	// It is mutually exclusive with IdentityResourceID.
   849  	ClientID string
   850  
   851  	// IdentityResourceID is the resource ID of the user-assigned identity to use during authentication.
   852  	// It is mutually exclusive with ClientID.
   853  	IdentityResourceID string
   854  }
   855  
   856  // NewServicePrincipalTokenFromManagedIdentity creates a ServicePrincipalToken using a managed identity.
   857  // It supports the following managed identity environments.
   858  // - App Service Environment (API version 2017-09-01 only)
   859  // - Cloud shell
   860  // - IMDS with a system or user assigned identity
   861  func NewServicePrincipalTokenFromManagedIdentity(resource string, options *ManagedIdentityOptions, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   862  	if options == nil {
   863  		options = &ManagedIdentityOptions{}
   864  	}
   865  	return newServicePrincipalTokenFromMSI("", resource, options.ClientID, options.IdentityResourceID, callbacks...)
   866  }
   867  
   868  func newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
   869  	if err := validateStringParam(resource, "resource"); err != nil {
   870  		return nil, err
   871  	}
   872  	if userAssignedID != "" && identityResourceID != "" {
   873  		return nil, errors.New("cannot specify userAssignedID and identityResourceID")
   874  	}
   875  	msiType, endpoint, err := getMSIType()
   876  	if err != nil {
   877  		logger.Instance.Writef(logger.LogError, "Error determining managed identity environment: %v\n", err)
   878  		return nil, err
   879  	}
   880  	logger.Instance.Writef(logger.LogInfo, "Managed identity environment is %s, endpoint is %s\n", msiType, endpoint)
   881  	if msiEndpoint != "" {
   882  		endpoint = msiEndpoint
   883  		logger.Instance.Writef(logger.LogInfo, "Managed identity custom endpoint is %s\n", endpoint)
   884  	}
   885  	msiEndpointURL, err := url.Parse(endpoint)
   886  	if err != nil {
   887  		return nil, err
   888  	}
   889  	// cloud shell sends its data in the request body
   890  	if msiType != msiTypeCloudShell {
   891  		v := url.Values{}
   892  		v.Set("resource", resource)
   893  		clientIDParam := "client_id"
   894  		switch msiType {
   895  		case msiTypeAppServiceV20170901:
   896  			clientIDParam = "clientid"
   897  			v.Set("api-version", appServiceAPIVersion2017)
   898  			break
   899  		case msiTypeIMDS:
   900  			v.Set("api-version", msiAPIVersion)
   901  		}
   902  		if userAssignedID != "" {
   903  			v.Set(clientIDParam, userAssignedID)
   904  		} else if identityResourceID != "" {
   905  			v.Set("mi_res_id", identityResourceID)
   906  		}
   907  		msiEndpointURL.RawQuery = v.Encode()
   908  	}
   909  
   910  	spt := &ServicePrincipalToken{
   911  		inner: servicePrincipalToken{
   912  			Token: newToken(),
   913  			OauthConfig: OAuthConfig{
   914  				TokenEndpoint: *msiEndpointURL,
   915  			},
   916  			Secret: &ServicePrincipalMSISecret{
   917  				msiType:          msiType,
   918  				clientResourceID: identityResourceID,
   919  			},
   920  			Resource:      resource,
   921  			AutoRefresh:   true,
   922  			RefreshWithin: defaultRefresh,
   923  			ClientID:      userAssignedID,
   924  		},
   925  		refreshLock:           &sync.RWMutex{},
   926  		sender:                sender(),
   927  		refreshCallbacks:      callbacks,
   928  		MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
   929  	}
   930  
   931  	return spt, nil
   932  }
   933  
   934  // internal type that implements TokenRefreshError
   935  type tokenRefreshError struct {
   936  	message string
   937  	resp    *http.Response
   938  }
   939  
   940  // Error implements the error interface which is part of the TokenRefreshError interface.
   941  func (tre tokenRefreshError) Error() string {
   942  	return tre.message
   943  }
   944  
   945  // Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
   946  func (tre tokenRefreshError) Response() *http.Response {
   947  	return tre.resp
   948  }
   949  
   950  func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
   951  	return tokenRefreshError{message: message, resp: resp}
   952  }
   953  
   954  // EnsureFresh will refresh the token if it will expire within the refresh window (as set by
   955  // RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
   956  func (spt *ServicePrincipalToken) EnsureFresh() error {
   957  	return spt.EnsureFreshWithContext(context.Background())
   958  }
   959  
   960  // EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
   961  // RefreshWithin) and autoRefresh flag is on.  This method is safe for concurrent use.
   962  func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
   963  	// must take the read lock when initially checking the token's expiration
   964  	if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) {
   965  		// take the write lock then check again to see if the token was already refreshed
   966  		spt.refreshLock.Lock()
   967  		defer spt.refreshLock.Unlock()
   968  		if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
   969  			return spt.refreshInternal(ctx, spt.inner.Resource)
   970  		}
   971  	}
   972  	return nil
   973  }
   974  
   975  // InvokeRefreshCallbacks calls any TokenRefreshCallbacks that were added to the SPT during initialization
   976  func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
   977  	if spt.refreshCallbacks != nil {
   978  		for _, callback := range spt.refreshCallbacks {
   979  			err := callback(spt.inner.Token)
   980  			if err != nil {
   981  				return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
   982  			}
   983  		}
   984  	}
   985  	return nil
   986  }
   987  
   988  // Refresh obtains a fresh token for the Service Principal.
   989  // This method is safe for concurrent use.
   990  func (spt *ServicePrincipalToken) Refresh() error {
   991  	return spt.RefreshWithContext(context.Background())
   992  }
   993  
   994  // RefreshWithContext obtains a fresh token for the Service Principal.
   995  // This method is safe for concurrent use.
   996  func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
   997  	spt.refreshLock.Lock()
   998  	defer spt.refreshLock.Unlock()
   999  	return spt.refreshInternal(ctx, spt.inner.Resource)
  1000  }
  1001  
  1002  // RefreshExchange refreshes the token, but for a different resource.
  1003  // This method is safe for concurrent use.
  1004  func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
  1005  	return spt.RefreshExchangeWithContext(context.Background(), resource)
  1006  }
  1007  
  1008  // RefreshExchangeWithContext refreshes the token, but for a different resource.
  1009  // This method is safe for concurrent use.
  1010  func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
  1011  	spt.refreshLock.Lock()
  1012  	defer spt.refreshLock.Unlock()
  1013  	return spt.refreshInternal(ctx, resource)
  1014  }
  1015  
  1016  func (spt *ServicePrincipalToken) getGrantType() string {
  1017  	switch spt.inner.Secret.(type) {
  1018  	case *ServicePrincipalUsernamePasswordSecret:
  1019  		return OAuthGrantTypeUserPass
  1020  	case *ServicePrincipalAuthorizationCodeSecret:
  1021  		return OAuthGrantTypeAuthorizationCode
  1022  	default:
  1023  		return OAuthGrantTypeClientCredentials
  1024  	}
  1025  }
  1026  
  1027  func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
  1028  	if spt.customRefreshFunc != nil {
  1029  		token, err := spt.customRefreshFunc(ctx, resource)
  1030  		if err != nil {
  1031  			return err
  1032  		}
  1033  		spt.inner.Token = *token
  1034  		return spt.InvokeRefreshCallbacks(spt.inner.Token)
  1035  	}
  1036  	req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
  1037  	if err != nil {
  1038  		return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
  1039  	}
  1040  	req.Header.Add("User-Agent", UserAgent())
  1041  	req = req.WithContext(ctx)
  1042  	var resp *http.Response
  1043  	authBodyFilter := func(b []byte) []byte {
  1044  		if logger.Level() != logger.LogAuth {
  1045  			return []byte("**REDACTED** authentication body")
  1046  		}
  1047  		return b
  1048  	}
  1049  	if msiSecret, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
  1050  		switch msiSecret.msiType {
  1051  		case msiTypeAppServiceV20170901:
  1052  			req.Method = http.MethodGet
  1053  			req.Header.Set("secret", os.Getenv(msiSecretEnv))
  1054  			break
  1055  		case msiTypeCloudShell:
  1056  			req.Header.Set("Metadata", "true")
  1057  			data := url.Values{}
  1058  			data.Set("resource", spt.inner.Resource)
  1059  			if spt.inner.ClientID != "" {
  1060  				data.Set("client_id", spt.inner.ClientID)
  1061  			} else if msiSecret.clientResourceID != "" {
  1062  				data.Set("msi_res_id", msiSecret.clientResourceID)
  1063  			}
  1064  			req.Body = ioutil.NopCloser(strings.NewReader(data.Encode()))
  1065  			req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  1066  			break
  1067  		case msiTypeIMDS:
  1068  			req.Method = http.MethodGet
  1069  			req.Header.Set("Metadata", "true")
  1070  			break
  1071  		}
  1072  		logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
  1073  		resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
  1074  	} else {
  1075  		v := url.Values{}
  1076  		v.Set("client_id", spt.inner.ClientID)
  1077  		v.Set("resource", resource)
  1078  
  1079  		if spt.inner.Token.RefreshToken != "" {
  1080  			v.Set("grant_type", OAuthGrantTypeRefreshToken)
  1081  			v.Set("refresh_token", spt.inner.Token.RefreshToken)
  1082  			// web apps must specify client_secret when refreshing tokens
  1083  			// see https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#refreshing-the-access-tokens
  1084  			if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
  1085  				err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
  1086  				if err != nil {
  1087  					return err
  1088  				}
  1089  			}
  1090  		} else {
  1091  			v.Set("grant_type", spt.getGrantType())
  1092  			err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
  1093  			if err != nil {
  1094  				return err
  1095  			}
  1096  		}
  1097  
  1098  		s := v.Encode()
  1099  		body := ioutil.NopCloser(strings.NewReader(s))
  1100  		req.ContentLength = int64(len(s))
  1101  		req.Header.Set(contentType, mimeTypeFormPost)
  1102  		req.Body = body
  1103  		logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
  1104  		resp, err = spt.sender.Do(req)
  1105  	}
  1106  
  1107  	// don't return a TokenRefreshError here; this will allow retry logic to apply
  1108  	if err != nil {
  1109  		return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
  1110  	} else if resp == nil {
  1111  		return fmt.Errorf("adal: received nil response and error")
  1112  	}
  1113  
  1114  	logger.Instance.WriteResponse(resp, logger.Filter{Body: authBodyFilter})
  1115  	defer resp.Body.Close()
  1116  	rb, err := ioutil.ReadAll(resp.Body)
  1117  
  1118  	if resp.StatusCode != http.StatusOK {
  1119  		if err != nil {
  1120  			return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v Endpoint %s", resp.StatusCode, err, req.URL.String()), resp)
  1121  		}
  1122  		return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s Endpoint %s", resp.StatusCode, string(rb), req.URL.String()), resp)
  1123  	}
  1124  
  1125  	// for the following error cases don't return a TokenRefreshError.  the operation succeeded
  1126  	// but some transient failure happened during deserialization.  by returning a generic error
  1127  	// the retry logic will kick in (we don't retry on TokenRefreshError).
  1128  
  1129  	if err != nil {
  1130  		return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
  1131  	}
  1132  	if len(strings.Trim(string(rb), " ")) == 0 {
  1133  		return fmt.Errorf("adal: Empty service principal token received during refresh")
  1134  	}
  1135  	token := struct {
  1136  		AccessToken  string `json:"access_token"`
  1137  		RefreshToken string `json:"refresh_token"`
  1138  
  1139  		// AAD returns expires_in as a string, ADFS returns it as an int
  1140  		ExpiresIn json.Number `json:"expires_in"`
  1141  		// expires_on can be in three formats, a UTC time stamp, or the number of seconds as a string *or* int.
  1142  		ExpiresOn interface{} `json:"expires_on"`
  1143  		NotBefore json.Number `json:"not_before"`
  1144  
  1145  		Resource string `json:"resource"`
  1146  		Type     string `json:"token_type"`
  1147  	}{}
  1148  	// return a TokenRefreshError in the follow error cases as the token is in an unexpected format
  1149  	err = json.Unmarshal(rb, &token)
  1150  	if err != nil {
  1151  		return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)), resp)
  1152  	}
  1153  	expiresOn := json.Number("")
  1154  	// ADFS doesn't include the expires_on field
  1155  	if token.ExpiresOn != nil {
  1156  		if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil {
  1157  			return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp)
  1158  		}
  1159  	}
  1160  	spt.inner.Token.AccessToken = token.AccessToken
  1161  	spt.inner.Token.RefreshToken = token.RefreshToken
  1162  	spt.inner.Token.ExpiresIn = token.ExpiresIn
  1163  	spt.inner.Token.ExpiresOn = expiresOn
  1164  	spt.inner.Token.NotBefore = token.NotBefore
  1165  	spt.inner.Token.Resource = token.Resource
  1166  	spt.inner.Token.Type = token.Type
  1167  
  1168  	return spt.InvokeRefreshCallbacks(spt.inner.Token)
  1169  }
  1170  
  1171  // converts expires_on to the number of seconds
  1172  func parseExpiresOn(s interface{}) (json.Number, error) {
  1173  	// the JSON unmarshaler treats JSON numbers unmarshaled into an interface{} as float64
  1174  	asFloat64, ok := s.(float64)
  1175  	if ok {
  1176  		// this is the number of seconds as int case
  1177  		return json.Number(strconv.FormatInt(int64(asFloat64), 10)), nil
  1178  	}
  1179  	asStr, ok := s.(string)
  1180  	if !ok {
  1181  		return "", fmt.Errorf("unexpected expires_on type %T", s)
  1182  	}
  1183  	// convert the expiration date to the number of seconds from the unix epoch
  1184  	timeToDuration := func(t time.Time) json.Number {
  1185  		return json.Number(strconv.FormatInt(t.UTC().Unix(), 10))
  1186  	}
  1187  	if _, err := json.Number(asStr).Int64(); err == nil {
  1188  		// this is the number of seconds case, no conversion required
  1189  		return json.Number(asStr), nil
  1190  	} else if eo, err := time.Parse(expiresOnDateFormatPM, asStr); err == nil {
  1191  		return timeToDuration(eo), nil
  1192  	} else if eo, err := time.Parse(expiresOnDateFormat, asStr); err == nil {
  1193  		return timeToDuration(eo), nil
  1194  	} else {
  1195  		// unknown format
  1196  		return json.Number(""), err
  1197  	}
  1198  }
  1199  
  1200  // retry logic specific to retrieving a token from the IMDS endpoint
  1201  func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
  1202  	// copied from client.go due to circular dependency
  1203  	retries := []int{
  1204  		http.StatusRequestTimeout,      // 408
  1205  		http.StatusTooManyRequests,     // 429
  1206  		http.StatusInternalServerError, // 500
  1207  		http.StatusBadGateway,          // 502
  1208  		http.StatusServiceUnavailable,  // 503
  1209  		http.StatusGatewayTimeout,      // 504
  1210  	}
  1211  	// extra retry status codes specific to IMDS
  1212  	retries = append(retries,
  1213  		http.StatusNotFound,
  1214  		http.StatusGone,
  1215  		// all remaining 5xx
  1216  		http.StatusNotImplemented,
  1217  		http.StatusHTTPVersionNotSupported,
  1218  		http.StatusVariantAlsoNegotiates,
  1219  		http.StatusInsufficientStorage,
  1220  		http.StatusLoopDetected,
  1221  		http.StatusNotExtended,
  1222  		http.StatusNetworkAuthenticationRequired)
  1223  
  1224  	// see https://docs.microsoft.com/en-us/azure/active-directory/managed-service-identity/how-to-use-vm-token#retry-guidance
  1225  
  1226  	const maxDelay time.Duration = 60 * time.Second
  1227  
  1228  	attempt := 0
  1229  	delay := time.Duration(0)
  1230  
  1231  	// maxAttempts is user-specified, ensure that its value is greater than zero else no request will be made
  1232  	if maxAttempts < 1 {
  1233  		maxAttempts = defaultMaxMSIRefreshAttempts
  1234  	}
  1235  
  1236  	for attempt < maxAttempts {
  1237  		if resp != nil && resp.Body != nil {
  1238  			io.Copy(ioutil.Discard, resp.Body)
  1239  			resp.Body.Close()
  1240  		}
  1241  		resp, err = sender.Do(req)
  1242  		// we want to retry if err is not nil or the status code is in the list of retry codes
  1243  		if err == nil && !responseHasStatusCode(resp, retries...) {
  1244  			return
  1245  		}
  1246  
  1247  		// perform exponential backoff with a cap.
  1248  		// must increment attempt before calculating delay.
  1249  		attempt++
  1250  		// the base value of 2 is the "delta backoff" as specified in the guidance doc
  1251  		delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
  1252  		if delay > maxDelay {
  1253  			delay = maxDelay
  1254  		}
  1255  
  1256  		select {
  1257  		case <-time.After(delay):
  1258  			// intentionally left blank
  1259  		case <-req.Context().Done():
  1260  			err = req.Context().Err()
  1261  			return
  1262  		}
  1263  	}
  1264  	return
  1265  }
  1266  
  1267  func responseHasStatusCode(resp *http.Response, codes ...int) bool {
  1268  	if resp != nil {
  1269  		for _, i := range codes {
  1270  			if i == resp.StatusCode {
  1271  				return true
  1272  			}
  1273  		}
  1274  	}
  1275  	return false
  1276  }
  1277  
  1278  // SetAutoRefresh enables or disables automatic refreshing of stale tokens.
  1279  func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
  1280  	spt.inner.AutoRefresh = autoRefresh
  1281  }
  1282  
  1283  // SetRefreshWithin sets the interval within which if the token will expire, EnsureFresh will
  1284  // refresh the token.
  1285  func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
  1286  	spt.inner.RefreshWithin = d
  1287  	return
  1288  }
  1289  
  1290  // SetSender sets the http.Client used when obtaining the Service Principal token. An
  1291  // undecorated http.Client is used by default.
  1292  func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
  1293  
  1294  // OAuthToken implements the OAuthTokenProvider interface.  It returns the current access token.
  1295  func (spt *ServicePrincipalToken) OAuthToken() string {
  1296  	spt.refreshLock.RLock()
  1297  	defer spt.refreshLock.RUnlock()
  1298  	return spt.inner.Token.OAuthToken()
  1299  }
  1300  
  1301  // Token returns a copy of the current token.
  1302  func (spt *ServicePrincipalToken) Token() Token {
  1303  	spt.refreshLock.RLock()
  1304  	defer spt.refreshLock.RUnlock()
  1305  	return spt.inner.Token
  1306  }
  1307  
  1308  // MultiTenantServicePrincipalToken contains tokens for multi-tenant authorization.
  1309  type MultiTenantServicePrincipalToken struct {
  1310  	PrimaryToken    *ServicePrincipalToken
  1311  	AuxiliaryTokens []*ServicePrincipalToken
  1312  }
  1313  
  1314  // PrimaryOAuthToken returns the primary authorization token.
  1315  func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string {
  1316  	return mt.PrimaryToken.OAuthToken()
  1317  }
  1318  
  1319  // AuxiliaryOAuthTokens returns one to three auxiliary authorization tokens.
  1320  func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string {
  1321  	tokens := make([]string, len(mt.AuxiliaryTokens))
  1322  	for i := range mt.AuxiliaryTokens {
  1323  		tokens[i] = mt.AuxiliaryTokens[i].OAuthToken()
  1324  	}
  1325  	return tokens
  1326  }
  1327  
  1328  // NewMultiTenantServicePrincipalToken creates a new MultiTenantServicePrincipalToken with the specified credentials and resource.
  1329  func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) {
  1330  	if err := validateStringParam(clientID, "clientID"); err != nil {
  1331  		return nil, err
  1332  	}
  1333  	if err := validateStringParam(secret, "secret"); err != nil {
  1334  		return nil, err
  1335  	}
  1336  	if err := validateStringParam(resource, "resource"); err != nil {
  1337  		return nil, err
  1338  	}
  1339  	auxTenants := multiTenantCfg.AuxiliaryTenants()
  1340  	m := MultiTenantServicePrincipalToken{
  1341  		AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
  1342  	}
  1343  	primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource)
  1344  	if err != nil {
  1345  		return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
  1346  	}
  1347  	m.PrimaryToken = primary
  1348  	for i := range auxTenants {
  1349  		aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource)
  1350  		if err != nil {
  1351  			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
  1352  		}
  1353  		m.AuxiliaryTokens[i] = aux
  1354  	}
  1355  	return &m, nil
  1356  }
  1357  
  1358  // NewMultiTenantServicePrincipalTokenFromCertificate creates a new MultiTenantServicePrincipalToken with the specified certificate credentials and resource.
  1359  func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTenantOAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string) (*MultiTenantServicePrincipalToken, error) {
  1360  	if err := validateStringParam(clientID, "clientID"); err != nil {
  1361  		return nil, err
  1362  	}
  1363  	if err := validateStringParam(resource, "resource"); err != nil {
  1364  		return nil, err
  1365  	}
  1366  	if certificate == nil {
  1367  		return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
  1368  	}
  1369  	if privateKey == nil {
  1370  		return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
  1371  	}
  1372  	auxTenants := multiTenantCfg.AuxiliaryTenants()
  1373  	m := MultiTenantServicePrincipalToken{
  1374  		AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
  1375  	}
  1376  	primary, err := NewServicePrincipalTokenWithSecret(
  1377  		*multiTenantCfg.PrimaryTenant(),
  1378  		clientID,
  1379  		resource,
  1380  		&ServicePrincipalCertificateSecret{
  1381  			PrivateKey:  privateKey,
  1382  			Certificate: certificate,
  1383  		},
  1384  	)
  1385  	if err != nil {
  1386  		return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
  1387  	}
  1388  	m.PrimaryToken = primary
  1389  	for i := range auxTenants {
  1390  		aux, err := NewServicePrincipalTokenWithSecret(
  1391  			*auxTenants[i],
  1392  			clientID,
  1393  			resource,
  1394  			&ServicePrincipalCertificateSecret{
  1395  				PrivateKey:  privateKey,
  1396  				Certificate: certificate,
  1397  			},
  1398  		)
  1399  		if err != nil {
  1400  			return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
  1401  		}
  1402  		m.AuxiliaryTokens[i] = aux
  1403  	}
  1404  	return &m, nil
  1405  }
  1406  
  1407  // MSIAvailable returns true if the MSI endpoint is available for authentication.
  1408  func MSIAvailable(ctx context.Context, s Sender) bool {
  1409  	msiType, _, err := getMSIType()
  1410  
  1411  	if err != nil {
  1412  		return false
  1413  	}
  1414  
  1415  	if msiType != msiTypeIMDS {
  1416  		return true
  1417  	}
  1418  
  1419  	if s == nil {
  1420  		s = sender()
  1421  	}
  1422  
  1423  	resp, err := getMSIEndpoint(ctx, s)
  1424  
  1425  	if err == nil {
  1426  		resp.Body.Close()
  1427  	}
  1428  
  1429  	return err == nil
  1430  }
  1431  

View as plain text