...

Source file src/github.com/Azure/go-autorest/autorest/adal/token_test.go

Documentation: github.com/Azure/go-autorest/autorest/adal

     1  package adal
     2  
     3  // Copyright 2017 Microsoft Corporation
     4  //
     5  //  Licensed under the Apache License, Version 2.0 (the "License");
     6  //  you may not use this file except in compliance with the License.
     7  //  You may obtain a copy of the License at
     8  //
     9  //      http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  //  Unless required by applicable law or agreed to in writing, software
    12  //  distributed under the License is distributed on an "AS IS" BASIS,
    13  //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  //  See the License for the specific language governing permissions and
    15  //  limitations under the License.
    16  
    17  import (
    18  	"context"
    19  	"crypto/rand"
    20  	"crypto/rsa"
    21  	"crypto/x509"
    22  	"crypto/x509/pkix"
    23  	"encoding/json"
    24  	"fmt"
    25  	"io/ioutil"
    26  	"math/big"
    27  	"net/http"
    28  	"net/http/httptest"
    29  	"net/url"
    30  	"os"
    31  	"path/filepath"
    32  	"reflect"
    33  	"strconv"
    34  	"strings"
    35  	"sync"
    36  	"testing"
    37  	"time"
    38  
    39  	"github.com/Azure/go-autorest/autorest/date"
    40  	"github.com/Azure/go-autorest/autorest/mocks"
    41  	jwt "github.com/golang-jwt/jwt/v4"
    42  	"github.com/stretchr/testify/assert"
    43  )
    44  
    45  const (
    46  	defaultFormData       = "client_id=id&client_secret=secret&grant_type=client_credentials&resource=resource"
    47  	defaultManualFormData = "client_id=id&grant_type=refresh_token&refresh_token=refreshtoken&resource=resource"
    48  )
    49  
    50  func TestTokenExpires(t *testing.T) {
    51  	tt := time.Now().Add(5 * time.Second)
    52  	tk := newTokenExpiresAt(tt)
    53  
    54  	if tk.Expires().Equal(tt) {
    55  		t.Fatalf("adal: Token#Expires miscalculated expiration time -- received %v, expected %v", tk.Expires(), tt)
    56  	}
    57  }
    58  
    59  func TestTokenIsExpired(t *testing.T) {
    60  	tk := newTokenExpiresAt(time.Now().Add(-5 * time.Second))
    61  
    62  	if !tk.IsExpired() {
    63  		t.Fatalf("adal: Token#IsExpired failed to mark a stale token as expired -- now %v, token expires at %v",
    64  			time.Now().UTC(), tk.Expires())
    65  	}
    66  }
    67  
    68  func TestTokenIsExpiredUninitialized(t *testing.T) {
    69  	tk := &Token{}
    70  
    71  	if !tk.IsExpired() {
    72  		t.Fatalf("adal: An uninitialized Token failed to mark itself as expired (expiration time %v)", tk.Expires())
    73  	}
    74  }
    75  
    76  func TestTokenIsNoExpired(t *testing.T) {
    77  	tk := newTokenExpiresAt(time.Now().Add(1000 * time.Second))
    78  
    79  	if tk.IsExpired() {
    80  		t.Fatalf("adal: Token marked a fresh token as expired -- now %v, token expires at %v", time.Now().UTC(), tk.Expires())
    81  	}
    82  }
    83  
    84  func TestTokenWillExpireIn(t *testing.T) {
    85  	d := 5 * time.Second
    86  	tk := newTokenExpiresIn(d)
    87  
    88  	if !tk.WillExpireIn(d) {
    89  		t.Fatal("adal: Token#WillExpireIn mismeasured expiration time")
    90  	}
    91  }
    92  
    93  func TestParseExpiresOn(t *testing.T) {
    94  	n := time.Now().UTC()
    95  	amPM := "AM"
    96  	if n.Hour() >= 12 {
    97  		amPM = "PM"
    98  	}
    99  	testcases := []struct {
   100  		Name   string
   101  		String string
   102  		Value  int64
   103  	}{
   104  		{
   105  			Name:   "integer",
   106  			String: "3600",
   107  			Value:  3600,
   108  		},
   109  		{
   110  			Name:   "timestamp with AM/PM",
   111  			String: fmt.Sprintf("%d/%d/%d %d:%02d:%02d %s +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second(), amPM),
   112  			Value:  n.Unix(),
   113  		},
   114  		{
   115  			Name:   "timestamp without AM/PM",
   116  			String: fmt.Sprintf("%02d/%02d/%02d %02d:%02d:%02d +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second()),
   117  			Value:  n.Unix(),
   118  		},
   119  	}
   120  	for _, tc := range testcases {
   121  		t.Run(tc.Name, func(subT *testing.T) {
   122  			jn, err := parseExpiresOn(tc.String)
   123  			if err != nil {
   124  				subT.Error(err)
   125  			}
   126  			i, err := jn.Int64()
   127  			if err != nil {
   128  				subT.Error(err)
   129  			}
   130  			if i != tc.Value {
   131  				subT.Logf("expected %d, got %d", tc.Value, i)
   132  				subT.Fail()
   133  			}
   134  		})
   135  	}
   136  }
   137  
   138  func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
   139  	spt := newServicePrincipalToken()
   140  
   141  	if !spt.inner.AutoRefresh {
   142  		t.Fatal("adal: ServicePrincipalToken did not default to automatic token refreshing")
   143  	}
   144  
   145  	spt.SetAutoRefresh(false)
   146  	if spt.inner.AutoRefresh {
   147  		t.Fatal("adal: ServicePrincipalToken#SetAutoRefresh did not disable automatic token refreshing")
   148  	}
   149  }
   150  
   151  func TestServicePrincipalTokenSetCustomRefreshFunc(t *testing.T) {
   152  	spt := newServicePrincipalToken()
   153  
   154  	var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
   155  		return nil, nil
   156  	}
   157  
   158  	if spt.customRefreshFunc != nil {
   159  		t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc had a default custom refresh func when it shouldn't")
   160  	}
   161  
   162  	spt.SetCustomRefreshFunc(refreshFunc)
   163  
   164  	if spt.customRefreshFunc == nil {
   165  		t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc didn't have a refresh func")
   166  	}
   167  }
   168  
   169  func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
   170  	spt := newServicePrincipalToken()
   171  
   172  	if spt.inner.RefreshWithin != defaultRefresh {
   173  		t.Fatal("adal: ServicePrincipalToken did not correctly set the default refresh interval")
   174  	}
   175  
   176  	spt.SetRefreshWithin(2 * defaultRefresh)
   177  	if spt.inner.RefreshWithin != 2*defaultRefresh {
   178  		t.Fatal("adal: ServicePrincipalToken#SetRefreshWithin did not set the refresh interval")
   179  	}
   180  }
   181  
   182  func TestServicePrincipalTokenSetSender(t *testing.T) {
   183  	spt := newServicePrincipalToken()
   184  
   185  	c := &http.Client{}
   186  	spt.SetSender(c)
   187  	if !reflect.DeepEqual(c, spt.sender) {
   188  		t.Fatal("adal: ServicePrincipalToken#SetSender did not set the sender")
   189  	}
   190  }
   191  
   192  func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) {
   193  	spt := newServicePrincipalToken()
   194  
   195  	called := false
   196  	var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
   197  		called = true
   198  		return &Token{}, nil
   199  	}
   200  	spt.SetCustomRefreshFunc(refreshFunc)
   201  	if called {
   202  		t.Fatalf("adal: ServicePrincipalToken#refreshInternal called the refresh function prior to refreshing")
   203  	}
   204  
   205  	spt.refreshInternal(context.Background(), "https://example.com")
   206  
   207  	if !called {
   208  		t.Fatalf("adal: ServicePrincipalToken#refreshInternal didn't call the refresh function")
   209  	}
   210  }
   211  
   212  func TestFederatedTokenRefreshUsesJwtCallback(t *testing.T) {
   213  	baseDir, err := os.MkdirTemp("", "")
   214  	assert.NoError(t, err)
   215  	jwtFile := filepath.Join(baseDir, "token")
   216  
   217  	jwtCallback := func() (string, error) {
   218  		jwt, err := os.ReadFile(jwtFile)
   219  		if err != nil {
   220  			return "", fmt.Errorf("failed to read a file with a federated token: %w", err)
   221  		}
   222  		return string(jwt), nil
   223  	}
   224  
   225  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   226  		jwt := r.FormValue("client_assertion")
   227  		refreshToken := r.FormValue("refresh_token")
   228  
   229  		if jwt == "aaa.aaa" {
   230  			w.Write([]byte(`{"access_token":"A","expires_in":"3600"}`))
   231  		} else if jwt == "bbb.bbb" {
   232  			w.Write([]byte(`{"access_token":"B","expires_in":"3600","refresh_token":"R"}`))
   233  		} else if refreshToken == "R" {
   234  			w.Write([]byte(`{"access_token":"C","expires_in":"3600"}`))
   235  		} else {
   236  			w.WriteHeader(http.StatusBadRequest)
   237  		}
   238  	}))
   239  
   240  	spt := newServicePrincipalTokenFederatedJwtCallback(t, jwtCallback, server.URL)
   241  
   242  	// token file does not exist, no such file error
   243  	err = spt.refreshInternal(context.Background(), "")
   244  	assert.ErrorIs(t, err, os.ErrNotExist)
   245  
   246  	// get jwt token from jwtFile
   247  	err = os.WriteFile(jwtFile, []byte("aaa.aaa"), 0600)
   248  	assert.NoError(t, err)
   249  	err = spt.refreshInternal(context.Background(), "")
   250  	assert.NoError(t, err)
   251  	assert.Equal(t, "A", spt.inner.Token.AccessToken)
   252  
   253  	// jwtFile is refreshed
   254  	err = os.WriteFile(jwtFile, []byte("bbb.bbb"), 0600)
   255  	assert.NoError(t, err)
   256  	err = spt.refreshInternal(context.Background(), "")
   257  	assert.NoError(t, err)
   258  	assert.Equal(t, "B", spt.inner.Token.AccessToken)
   259  	// refresh_token is set
   260  	assert.Equal(t, "R", spt.inner.Token.RefreshToken)
   261  
   262  	// after refresh_token is set, the callback won't be called
   263  	err = spt.refreshInternal(context.Background(), "")
   264  	assert.NoError(t, err)
   265  	assert.Equal(t, "C", spt.inner.Token.AccessToken)
   266  }
   267  
   268  func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
   269  	spt := newServicePrincipalToken()
   270  
   271  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   272  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   273  
   274  	c := mocks.NewSender()
   275  	s := DecorateSender(c,
   276  		(func() SendDecorator {
   277  			return func(s Sender) Sender {
   278  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   279  					if r.Method != "POST" {
   280  						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
   281  					}
   282  					return resp, nil
   283  				})
   284  			}
   285  		})())
   286  	spt.SetSender(s)
   287  	err := spt.Refresh()
   288  	if err != nil {
   289  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   290  	}
   291  
   292  	if body.IsOpen() {
   293  		t.Fatalf("the response was not closed!")
   294  	}
   295  }
   296  
   297  func TestNewServicePrincipalTokenFromManagedIdentity(t *testing.T) {
   298  	spt, err := NewServicePrincipalTokenFromManagedIdentity("https://resource", nil)
   299  	if err != nil {
   300  		t.Fatalf("Failed to get MSI SPT: %v", err)
   301  	}
   302  
   303  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   304  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   305  
   306  	c := mocks.NewSender()
   307  	s := DecorateSender(c,
   308  		(func() SendDecorator {
   309  			return func(s Sender) Sender {
   310  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   311  					if r.Method != "GET" {
   312  						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
   313  					}
   314  					if h := r.Header.Get("Metadata"); h != "true" {
   315  						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
   316  					}
   317  					return resp, nil
   318  				})
   319  			}
   320  		})())
   321  	spt.SetSender(s)
   322  	err = spt.Refresh()
   323  	if err != nil {
   324  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   325  	}
   326  
   327  	if body.IsOpen() {
   328  		t.Fatalf("the response was not closed!")
   329  	}
   330  }
   331  
   332  func TestServicePrincipalTokenFromMSICloudshell(t *testing.T) {
   333  	os.Setenv(msiEndpointEnv, "http://dummy")
   334  	defer func() {
   335  		os.Unsetenv(msiEndpointEnv)
   336  	}()
   337  	spt, err := NewServicePrincipalTokenFromMSI("", "https://resource")
   338  	if err != nil {
   339  		t.Fatalf("Failed to get MSI SPT: %v", err)
   340  	}
   341  
   342  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   343  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   344  
   345  	c := mocks.NewSender()
   346  	s := DecorateSender(c,
   347  		(func() SendDecorator {
   348  			return func(s Sender) Sender {
   349  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   350  					if r.Method != http.MethodPost {
   351  						t.Fatalf("adal: cloudshell did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
   352  					}
   353  					if h := r.Header.Get("Metadata"); h != "true" {
   354  						t.Fatalf("adal: cloudshell did not correctly set Metadata header")
   355  					}
   356  					if h := r.Header.Get("Content-Type"); h != "application/x-www-form-urlencoded" {
   357  						t.Fatalf("adal: cloudshell did not correctly set Content-Type header")
   358  					}
   359  					return resp, nil
   360  				})
   361  			}
   362  		})())
   363  	spt.SetSender(s)
   364  	err = spt.Refresh()
   365  	if err != nil {
   366  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   367  	}
   368  
   369  	if body.IsOpen() {
   370  		t.Fatalf("the response was not closed!")
   371  	}
   372  }
   373  
   374  func TestServicePrincipalTokenFromMSIRefreshZeroRetry(t *testing.T) {
   375  	resource := "https://resource"
   376  	cb := func(token Token) error { return nil }
   377  
   378  	endpoint, _ := GetMSIVMEndpoint()
   379  	spt, err := NewServicePrincipalTokenFromMSI(endpoint, resource, cb)
   380  	if err != nil {
   381  		t.Fatalf("Failed to get MSI SPT: %v", err)
   382  	}
   383  	spt.MaxMSIRefreshAttempts = 1
   384  
   385  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   386  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   387  
   388  	c := mocks.NewSender()
   389  	s := DecorateSender(c,
   390  		(func() SendDecorator {
   391  			return func(s Sender) Sender {
   392  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   393  					// second invocation, perform MSI request validation
   394  					if r.Method != "GET" {
   395  						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
   396  					}
   397  					if h := r.Header.Get("Metadata"); h != "true" {
   398  						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
   399  					}
   400  					return resp, nil
   401  				})
   402  			}
   403  		})())
   404  	spt.SetSender(s)
   405  	err = spt.Refresh()
   406  	if err != nil {
   407  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   408  	}
   409  
   410  	if body.IsOpen() {
   411  		t.Fatalf("the response was not closed!")
   412  	}
   413  }
   414  
   415  func TestServicePrincipalTokenFromASE(t *testing.T) {
   416  	os.Setenv("MSI_ENDPOINT", "http://localhost")
   417  	os.Setenv("MSI_SECRET", "super")
   418  	defer func() {
   419  		os.Unsetenv("MSI_ENDPOINT")
   420  		os.Unsetenv("MSI_SECRET")
   421  	}()
   422  	resource := "https://resource"
   423  	spt, err := NewServicePrincipalTokenFromMSI("", resource)
   424  	if err != nil {
   425  		t.Fatalf("Failed to get MSI SPT: %v", err)
   426  	}
   427  	spt.MaxMSIRefreshAttempts = 1
   428  	// expires_on is sent in UTC
   429  	nowTime := time.Now()
   430  	expiresOn := nowTime.UTC().Add(time.Hour)
   431  	// use int format for expires_in
   432  	body := mocks.NewBody(newTokenJSON("3600", expiresOn.Format(expiresOnDateFormat), "test"))
   433  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   434  
   435  	c := mocks.NewSender()
   436  	s := DecorateSender(c,
   437  		(func() SendDecorator {
   438  			return func(s Sender) Sender {
   439  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   440  					if r.Method != "GET" {
   441  						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
   442  					}
   443  					if h := r.Header.Get(metadataHeader); h != "" {
   444  						t.Fatalf("adal: ServicePrincipalToken#Refresh incorrectly set Metadata header for ASE")
   445  					}
   446  					if s := r.Header.Get(secretHeader); s != "super" {
   447  						t.Fatalf("adal: unexpected secret header value %s", s)
   448  					}
   449  					if r.URL.Host != "localhost" {
   450  						t.Fatalf("adal: unexpected host %s", r.URL.Host)
   451  					}
   452  					qp := r.URL.Query()
   453  					if api := qp.Get("api-version"); api != appServiceAPIVersion2017 {
   454  						t.Fatalf("adal: unexpected api-version %s", api)
   455  					}
   456  					return resp, nil
   457  				})
   458  			}
   459  		})())
   460  	spt.SetSender(s)
   461  	err = spt.Refresh()
   462  	if err != nil {
   463  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   464  	}
   465  	v, err := spt.inner.Token.ExpiresOn.Int64()
   466  	if err != nil {
   467  		t.Fatalf("adal: failed to get ExpiresOn %v", err)
   468  	}
   469  	if nowAsUnix := nowTime.Add(time.Hour).Unix(); v != nowAsUnix {
   470  		t.Fatalf("adal: expected %v, got %v", nowAsUnix, v)
   471  	}
   472  	if body.IsOpen() {
   473  		t.Fatalf("the response was not closed!")
   474  	}
   475  }
   476  
   477  func TestServicePrincipalTokenFromADFS(t *testing.T) {
   478  	os.Setenv("MSI_ENDPOINT", "http://localhost")
   479  	os.Setenv("MSI_SECRET", "super")
   480  	defer func() {
   481  		os.Unsetenv("MSI_ENDPOINT")
   482  		os.Unsetenv("MSI_SECRET")
   483  	}()
   484  	resource := "https://resource"
   485  	endpoint, _ := GetMSIEndpoint()
   486  	spt, err := NewServicePrincipalTokenFromMSI(endpoint, resource)
   487  	if err != nil {
   488  		t.Fatalf("Failed to get MSI SPT: %v", err)
   489  	}
   490  	spt.MaxMSIRefreshAttempts = 1
   491  	const expiresIn = 3600
   492  	body := mocks.NewBody(newADFSTokenJSON(expiresIn))
   493  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   494  
   495  	c := mocks.NewSender()
   496  	s := DecorateSender(c,
   497  		(func() SendDecorator {
   498  			return func(s Sender) Sender {
   499  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   500  					if r.Method != "GET" {
   501  						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
   502  					}
   503  					if h := r.Header.Get(metadataHeader); h != "" {
   504  						t.Fatalf("adal: ServicePrincipalToken#Refresh incorrectly set Metadata header for ASE")
   505  					}
   506  					if s := r.Header.Get(secretHeader); s != "super" {
   507  						t.Fatalf("adal: unexpected secret header value %s", s)
   508  					}
   509  					if r.URL.Host != "localhost" {
   510  						t.Fatalf("adal: unexpected host %s", r.URL.Host)
   511  					}
   512  					qp := r.URL.Query()
   513  					if api := qp.Get("api-version"); api != appServiceAPIVersion2017 {
   514  						t.Fatalf("adal: unexpected api-version %s", api)
   515  					}
   516  					return resp, nil
   517  				})
   518  			}
   519  		})())
   520  	spt.SetSender(s)
   521  	err = spt.Refresh()
   522  	if err != nil {
   523  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   524  	}
   525  	i, err := spt.inner.Token.ExpiresIn.Int64()
   526  	if err != nil {
   527  		t.Fatalf("unexpected parsing of expires_in: %v", err)
   528  	}
   529  	if i != expiresIn {
   530  		t.Fatalf("unexpected expires_in %d", i)
   531  	}
   532  	if spt.inner.Token.ExpiresOn.String() != "" {
   533  		t.Fatal("expected empty expires_on")
   534  	}
   535  	if body.IsOpen() {
   536  		t.Fatalf("the response was not closed!")
   537  	}
   538  }
   539  
   540  func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) {
   541  	ctx, cancel := context.WithCancel(context.Background())
   542  	endpoint, _ := GetMSIVMEndpoint()
   543  
   544  	spt, err := NewServicePrincipalTokenFromMSI(endpoint, "https://resource")
   545  	if err != nil {
   546  		t.Fatalf("Failed to get MSI SPT: %v", err)
   547  	}
   548  
   549  	c := mocks.NewSender()
   550  	c.AppendAndRepeatResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError), 5)
   551  
   552  	var wg sync.WaitGroup
   553  	wg.Add(1)
   554  	start := time.Now()
   555  	end := time.Now()
   556  
   557  	go func() {
   558  		spt.SetSender(c)
   559  		err = spt.RefreshWithContext(ctx)
   560  		end = time.Now()
   561  		wg.Done()
   562  	}()
   563  
   564  	cancel()
   565  	wg.Wait()
   566  	time.Sleep(5 * time.Millisecond)
   567  
   568  	if end.Sub(start) >= time.Second {
   569  		t.Fatalf("TestServicePrincipalTokenFromMSIRefreshCancel failed to cancel")
   570  	}
   571  }
   572  
   573  func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) {
   574  	spt := newServicePrincipalToken()
   575  
   576  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   577  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   578  
   579  	c := mocks.NewSender()
   580  	s := DecorateSender(c,
   581  		(func() SendDecorator {
   582  			return func(s Sender) Sender {
   583  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   584  					if r.Header.Get(http.CanonicalHeaderKey("Content-Type")) != "application/x-www-form-urlencoded" {
   585  						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Content-Type -- expected %v, received %v",
   586  							"application/x-form-urlencoded",
   587  							r.Header.Get(http.CanonicalHeaderKey("Content-Type")))
   588  					}
   589  					return resp, nil
   590  				})
   591  			}
   592  		})())
   593  	spt.SetSender(s)
   594  	err := spt.Refresh()
   595  	if err != nil {
   596  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   597  	}
   598  }
   599  
   600  func TestServicePrincipalTokenRefreshSetsURL(t *testing.T) {
   601  	spt := newServicePrincipalToken()
   602  
   603  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   604  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   605  
   606  	c := mocks.NewSender()
   607  	s := DecorateSender(c,
   608  		(func() SendDecorator {
   609  			return func(s Sender) Sender {
   610  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   611  					if r.URL.String() != TestOAuthConfig.TokenEndpoint.String() {
   612  						t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the URL -- expected %v, received %v",
   613  							TestOAuthConfig.TokenEndpoint, r.URL)
   614  					}
   615  					return resp, nil
   616  				})
   617  			}
   618  		})())
   619  	spt.SetSender(s)
   620  	err := spt.Refresh()
   621  	if err != nil {
   622  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   623  	}
   624  }
   625  
   626  func testServicePrincipalTokenRefreshSetsBody(t *testing.T, spt *ServicePrincipalToken, f func(*testing.T, []byte)) {
   627  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   628  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   629  
   630  	c := mocks.NewSender()
   631  	s := DecorateSender(c,
   632  		(func() SendDecorator {
   633  			return func(s Sender) Sender {
   634  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   635  					b, err := ioutil.ReadAll(r.Body)
   636  					if err != nil {
   637  						t.Fatalf("adal: Failed to read body of Service Principal token request (%v)", err)
   638  					}
   639  					f(t, b)
   640  					return resp, nil
   641  				})
   642  			}
   643  		})())
   644  	spt.SetSender(s)
   645  	err := spt.Refresh()
   646  	if err != nil {
   647  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   648  	}
   649  }
   650  
   651  func TestServicePrincipalTokenManualRefreshSetsBody(t *testing.T) {
   652  	sptManual := newServicePrincipalTokenManual()
   653  	testServicePrincipalTokenRefreshSetsBody(t, sptManual, func(t *testing.T, b []byte) {
   654  		if string(b) != defaultManualFormData {
   655  			t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
   656  				defaultManualFormData, string(b))
   657  		}
   658  	})
   659  }
   660  
   661  func TestServicePrincipalTokenCertficateRefreshSetsBody(t *testing.T) {
   662  	sptCert := newServicePrincipalTokenCertificate(t)
   663  	testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) {
   664  		body := string(b)
   665  
   666  		values, _ := url.ParseQuery(body)
   667  		if values["client_assertion_type"][0] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ||
   668  			values["client_id"][0] != "id" ||
   669  			values["grant_type"][0] != "client_credentials" ||
   670  			values["resource"][0] != "resource" {
   671  			t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
   672  		}
   673  
   674  		tok, _ := jwt.Parse(values["client_assertion"][0], nil)
   675  		if tok == nil {
   676  			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to be a JWT")
   677  		}
   678  		if _, ok := tok.Header["x5t"]; !ok {
   679  			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5t header")
   680  		}
   681  		if _, ok := tok.Header["x5c"]; !ok {
   682  			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5c header")
   683  		}
   684  		claims, ok := tok.Claims.(jwt.MapClaims)
   685  		if !ok {
   686  			t.Fatalf("expected MapClaims, got %T", tok.Claims)
   687  		}
   688  		if err := claims.Valid(); err != nil {
   689  			t.Fatalf("invalid claim: %v", err)
   690  		}
   691  		if aud := claims["aud"]; aud != "https://login.test.com/SomeTenantID/oauth2/token?api-version=1.0" {
   692  			t.Fatalf("unexpected aud: %s", aud)
   693  		}
   694  		if iss := claims["iss"]; iss != "id" {
   695  			t.Fatalf("unexpected iss: %s", iss)
   696  		}
   697  		if sub := claims["sub"]; sub != "id" {
   698  			t.Fatalf("unexpected sub: %s", sub)
   699  		}
   700  	})
   701  }
   702  
   703  func TestServicePrincipalTokenUsernamePasswordRefreshSetsBody(t *testing.T) {
   704  	spt := newServicePrincipalTokenUsernamePassword(t)
   705  	testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
   706  		body := string(b)
   707  
   708  		values, _ := url.ParseQuery(body)
   709  		if values["client_id"][0] != "id" ||
   710  			values["grant_type"][0] != "password" ||
   711  			values["username"][0] != "username" ||
   712  			values["password"][0] != "password" ||
   713  			values["resource"][0] != "resource" {
   714  			t.Fatalf("adal: ServicePrincipalTokenUsernamePassword#Refresh did not correctly set the HTTP Request Body.")
   715  		}
   716  	})
   717  }
   718  
   719  func TestServicePrincipalTokenAuthorizationCodeRefreshSetsBody(t *testing.T) {
   720  	spt := newServicePrincipalTokenAuthorizationCode(t)
   721  	testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
   722  		body := string(b)
   723  
   724  		values, _ := url.ParseQuery(body)
   725  		if values["client_id"][0] != "id" ||
   726  			values["grant_type"][0] != OAuthGrantTypeAuthorizationCode ||
   727  			values["code"][0] != "code" ||
   728  			values["client_secret"][0] != "clientSecret" ||
   729  			values["redirect_uri"][0] != "http://redirectUri/getToken" ||
   730  			values["resource"][0] != "resource" {
   731  			t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.")
   732  		}
   733  	})
   734  	testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
   735  		body := string(b)
   736  
   737  		values, _ := url.ParseQuery(body)
   738  		if values["client_id"][0] != "id" ||
   739  			values["grant_type"][0] != OAuthGrantTypeRefreshToken ||
   740  			values["code"][0] != "code" ||
   741  			values["client_secret"][0] != "clientSecret" ||
   742  			values["redirect_uri"][0] != "http://redirectUri/getToken" ||
   743  			values["resource"][0] != "resource" {
   744  			t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.")
   745  		}
   746  	})
   747  }
   748  
   749  func TestServicePrincipalTokenSecretRefreshSetsBody(t *testing.T) {
   750  	spt := newServicePrincipalToken()
   751  	testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
   752  		if string(b) != defaultFormData {
   753  			t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
   754  				defaultFormData, string(b))
   755  		}
   756  
   757  	})
   758  }
   759  
   760  func TestServicePrincipalTokenFederatedJwtRefreshSetsBody(t *testing.T) {
   761  	sptCert := newServicePrincipalTokenFederatedJwt(t)
   762  	testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) {
   763  		body := string(b)
   764  
   765  		values, _ := url.ParseQuery(body)
   766  		if values["client_assertion_type"][0] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ||
   767  			values["client_id"][0] != "id" ||
   768  			values["grant_type"][0] != "client_credentials" ||
   769  			values["resource"][0] != "resource" {
   770  			t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
   771  		}
   772  
   773  		tok, _ := jwt.Parse(values["client_assertion"][0], nil)
   774  		if tok == nil {
   775  			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to be a JWT")
   776  		}
   777  		if _, ok := tok.Header["typ"]; !ok {
   778  			t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an typ header")
   779  		}
   780  
   781  		claims, ok := tok.Claims.(jwt.MapClaims)
   782  		if !ok {
   783  			t.Fatalf("expected MapClaims, got %T", tok.Claims)
   784  		}
   785  		if err := claims.Valid(); err != nil {
   786  			t.Fatalf("invalid claim: %v", err)
   787  		}
   788  		if aud := claims["aud"]; aud != "testAudience" {
   789  			t.Fatalf("unexpected aud: %s", aud)
   790  		}
   791  		if iss := claims["iss"]; iss != "id" {
   792  			t.Fatalf("unexpected iss: %s", iss)
   793  		}
   794  		if sub := claims["sub"]; sub != "id" {
   795  			t.Fatalf("unexpected sub: %s", sub)
   796  		}
   797  	})
   798  }
   799  
   800  func TestServicePrincipalTokenRefreshClosesRequestBody(t *testing.T) {
   801  	spt := newServicePrincipalToken()
   802  
   803  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   804  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   805  
   806  	c := mocks.NewSender()
   807  	s := DecorateSender(c,
   808  		(func() SendDecorator {
   809  			return func(s Sender) Sender {
   810  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   811  					return resp, nil
   812  				})
   813  			}
   814  		})())
   815  	spt.SetSender(s)
   816  	err := spt.Refresh()
   817  	if err != nil {
   818  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   819  	}
   820  	if resp.Body.(*mocks.Body).IsOpen() {
   821  		t.Fatal("adal: ServicePrincipalToken#Refresh failed to close the HTTP Response Body")
   822  	}
   823  }
   824  
   825  func TestServicePrincipalTokenRefreshRejectsResponsesWithStatusNotOK(t *testing.T) {
   826  	spt := newServicePrincipalToken()
   827  
   828  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   829  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusUnauthorized, "Unauthorized")
   830  
   831  	c := mocks.NewSender()
   832  	s := DecorateSender(c,
   833  		(func() SendDecorator {
   834  			return func(s Sender) Sender {
   835  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   836  					return resp, nil
   837  				})
   838  			}
   839  		})())
   840  	spt.SetSender(s)
   841  	err := spt.Refresh()
   842  	if err == nil {
   843  		t.Fatalf("adal: ServicePrincipalToken#Refresh should reject a response with status != %d", http.StatusOK)
   844  	}
   845  }
   846  
   847  func TestServicePrincipalTokenRefreshRejectsEmptyBody(t *testing.T) {
   848  	spt := newServicePrincipalToken()
   849  
   850  	c := mocks.NewSender()
   851  	s := DecorateSender(c,
   852  		(func() SendDecorator {
   853  			return func(s Sender) Sender {
   854  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   855  					return mocks.NewResponse(), nil
   856  				})
   857  			}
   858  		})())
   859  	spt.SetSender(s)
   860  	err := spt.Refresh()
   861  	if err == nil {
   862  		t.Fatal("adal: ServicePrincipalToken#Refresh should reject an empty token")
   863  	}
   864  }
   865  
   866  func TestServicePrincipalTokenRefreshPropagatesErrors(t *testing.T) {
   867  	spt := newServicePrincipalToken()
   868  
   869  	c := mocks.NewSender()
   870  	c.SetError(fmt.Errorf("Faux Error"))
   871  	spt.SetSender(c)
   872  
   873  	err := spt.Refresh()
   874  	if err == nil {
   875  		t.Fatal("adal: Failed to propagate the request error")
   876  	}
   877  }
   878  
   879  func TestServicePrincipalTokenRefreshReturnsErrorIfNotOk(t *testing.T) {
   880  	spt := newServicePrincipalToken()
   881  
   882  	c := mocks.NewSender()
   883  	c.AppendResponse(mocks.NewResponseWithStatus("401 NotAuthorized", http.StatusUnauthorized))
   884  	spt.SetSender(c)
   885  
   886  	err := spt.Refresh()
   887  	if err == nil {
   888  		t.Fatalf("adal: Failed to return an when receiving a status code other than HTTP %d", http.StatusOK)
   889  	}
   890  }
   891  
   892  func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
   893  	spt := newServicePrincipalToken()
   894  
   895  	expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
   896  	j := newTokenJSON(`"3600"`, expiresOn, "resource")
   897  	resp := mocks.NewResponseWithContent(j)
   898  	c := mocks.NewSender()
   899  	s := DecorateSender(c,
   900  		(func() SendDecorator {
   901  			return func(s Sender) Sender {
   902  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   903  					return resp, nil
   904  				})
   905  			}
   906  		})())
   907  	spt.SetSender(s)
   908  
   909  	err := spt.Refresh()
   910  	if err != nil {
   911  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
   912  	} else if spt.inner.Token.AccessToken != "accessToken" ||
   913  		spt.inner.Token.ExpiresIn != "3600" ||
   914  		spt.inner.Token.ExpiresOn != json.Number(expiresOn) ||
   915  		spt.inner.Token.NotBefore != json.Number(expiresOn) ||
   916  		spt.inner.Token.Resource != "resource" ||
   917  		spt.inner.Token.Type != "Bearer" {
   918  		t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v",
   919  			j, *spt)
   920  	}
   921  }
   922  
   923  func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
   924  	spt := newServicePrincipalToken()
   925  	expireToken(&spt.inner.Token)
   926  
   927  	body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
   928  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   929  
   930  	f := false
   931  	c := mocks.NewSender()
   932  	s := DecorateSender(c,
   933  		(func() SendDecorator {
   934  			return func(s Sender) Sender {
   935  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   936  					f = true
   937  					return resp, nil
   938  				})
   939  			}
   940  		})())
   941  	spt.SetSender(s)
   942  	err := spt.EnsureFresh()
   943  	if err != nil {
   944  		t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
   945  	}
   946  	if !f {
   947  		t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
   948  	}
   949  }
   950  
   951  func TestServicePrincipalTokenEnsureFreshWithIntExpiresOn(t *testing.T) {
   952  	spt := newServicePrincipalToken()
   953  	expireToken(&spt.inner.Token)
   954  
   955  	body := mocks.NewBody(newTokenJSONIntExpiresOn(`"3600"`, 12345, "test"))
   956  	resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
   957  
   958  	f := false
   959  	c := mocks.NewSender()
   960  	s := DecorateSender(c,
   961  		(func() SendDecorator {
   962  			return func(s Sender) Sender {
   963  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
   964  					f = true
   965  					return resp, nil
   966  				})
   967  			}
   968  		})())
   969  	spt.SetSender(s)
   970  	err := spt.EnsureFresh()
   971  	if err != nil {
   972  		t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
   973  	}
   974  	if !f {
   975  		t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
   976  	}
   977  }
   978  
   979  func TestServicePrincipalTokenEnsureFreshFails1(t *testing.T) {
   980  	spt := newServicePrincipalToken()
   981  	expireToken(&spt.inner.Token)
   982  
   983  	c := mocks.NewSender()
   984  	c.SetError(fmt.Errorf("some failure"))
   985  
   986  	spt.SetSender(c)
   987  	err := spt.EnsureFresh()
   988  	if err == nil {
   989  		t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error")
   990  	}
   991  	if _, ok := err.(TokenRefreshError); ok {
   992  		t.Fatal("adal: ServicePrincipalToken#EnsureFresh unexpected TokenRefreshError")
   993  	}
   994  }
   995  
   996  func TestServicePrincipalTokenEnsureFreshFails2(t *testing.T) {
   997  	spt := newServicePrincipalToken()
   998  	expireToken(&spt.inner.Token)
   999  
  1000  	c := mocks.NewSender()
  1001  	c.AppendResponse(mocks.NewResponseWithStatus("bad request", http.StatusBadRequest))
  1002  
  1003  	spt.SetSender(c)
  1004  	err := spt.EnsureFresh()
  1005  	if err == nil {
  1006  		t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error")
  1007  	}
  1008  	if _, ok := err.(TokenRefreshError); !ok {
  1009  		t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return a TokenRefreshError")
  1010  	}
  1011  }
  1012  
  1013  func TestServicePrincipalTokenEnsureFreshSkipsIfFresh(t *testing.T) {
  1014  	spt := newServicePrincipalToken()
  1015  	setTokenToExpireIn(&spt.inner.Token, 1000*time.Second)
  1016  
  1017  	f := false
  1018  	c := mocks.NewSender()
  1019  	s := DecorateSender(c,
  1020  		(func() SendDecorator {
  1021  			return func(s Sender) Sender {
  1022  				return SenderFunc(func(r *http.Request) (*http.Response, error) {
  1023  					f = true
  1024  					return mocks.NewResponse(), nil
  1025  				})
  1026  			}
  1027  		})())
  1028  	spt.SetSender(s)
  1029  	err := spt.EnsureFresh()
  1030  	if err != nil {
  1031  		t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
  1032  	}
  1033  	if f {
  1034  		t.Fatal("adal: ServicePrincipalToken#EnsureFresh invoked Refresh for fresh token")
  1035  	}
  1036  }
  1037  
  1038  func TestRefreshCallback(t *testing.T) {
  1039  	callbackTriggered := false
  1040  	spt := newServicePrincipalToken(func(Token) error {
  1041  		callbackTriggered = true
  1042  		return nil
  1043  	})
  1044  
  1045  	expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
  1046  
  1047  	sender := mocks.NewSender()
  1048  	j := newTokenJSON(`"3600"`, expiresOn, "resource")
  1049  	sender.AppendResponse(mocks.NewResponseWithContent(j))
  1050  	spt.SetSender(sender)
  1051  	err := spt.Refresh()
  1052  	if err != nil {
  1053  		t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
  1054  	}
  1055  	if !callbackTriggered {
  1056  		t.Fatalf("adal: RefreshCallback failed to trigger call callback")
  1057  	}
  1058  }
  1059  
  1060  func TestRefreshCallbackErrorPropagates(t *testing.T) {
  1061  	errorText := "this is an error text"
  1062  	spt := newServicePrincipalToken(func(Token) error {
  1063  		return fmt.Errorf(errorText)
  1064  	})
  1065  
  1066  	expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
  1067  
  1068  	sender := mocks.NewSender()
  1069  	j := newTokenJSON(`"3600"`, expiresOn, "resource")
  1070  	sender.AppendResponse(mocks.NewResponseWithContent(j))
  1071  	spt.SetSender(sender)
  1072  	err := spt.Refresh()
  1073  
  1074  	if err == nil || !strings.Contains(err.Error(), errorText) {
  1075  		t.Fatalf("adal: RefreshCallback failed to propagate error")
  1076  	}
  1077  }
  1078  
  1079  // This demonstrates the danger of manual token without a refresh token
  1080  func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
  1081  	spt := newServicePrincipalTokenManual()
  1082  	spt.inner.Token.RefreshToken = ""
  1083  	err := spt.Refresh()
  1084  	if err == nil {
  1085  		t.Fatalf("adal: ServicePrincipalToken#Refresh should have failed with a ManualTokenSecret without a refresh token")
  1086  	}
  1087  }
  1088  
  1089  func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
  1090  	const resource = "https://resource"
  1091  	cb := func(token Token) error { return nil }
  1092  
  1093  	spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
  1094  	if err != nil {
  1095  		t.Fatalf("Failed to get MSI SPT: %v", err)
  1096  	}
  1097  
  1098  	// check some of the SPT fields
  1099  	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
  1100  		t.Fatal("SPT secret was not of MSI type")
  1101  	}
  1102  
  1103  	if spt.inner.Resource != resource {
  1104  		t.Fatal("SPT came back with incorrect resource")
  1105  	}
  1106  
  1107  	if len(spt.refreshCallbacks) != 1 {
  1108  		t.Fatal("SPT had incorrect refresh callbacks.")
  1109  	}
  1110  }
  1111  
  1112  func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
  1113  	const (
  1114  		resource = "https://resource"
  1115  		userID   = "abc123"
  1116  	)
  1117  	cb := func(token Token) error { return nil }
  1118  
  1119  	spt, err := NewServicePrincipalTokenFromMSIWithUserAssignedID("http://msiendpoint/", resource, userID, cb)
  1120  	if err != nil {
  1121  		t.Fatalf("Failed to get MSI SPT: %v", err)
  1122  	}
  1123  
  1124  	// check some of the SPT fields
  1125  	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
  1126  		t.Fatal("SPT secret was not of MSI type")
  1127  	}
  1128  
  1129  	if spt.inner.Resource != resource {
  1130  		t.Fatal("SPT came back with incorrect resource")
  1131  	}
  1132  
  1133  	if len(spt.refreshCallbacks) != 1 {
  1134  		t.Fatal("SPT had incorrect refresh callbacks.")
  1135  	}
  1136  
  1137  	if spt.inner.ClientID != userID {
  1138  		t.Fatal("SPT had incorrect client ID")
  1139  	}
  1140  }
  1141  
  1142  func TestNewServicePrincipalTokenFromMSIWithIdentityResourceID(t *testing.T) {
  1143  	const (
  1144  		resource           = "https://resource"
  1145  		identityResourceID = "/subscriptions/testSub/resourceGroups/testGroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test-identity"
  1146  	)
  1147  	cb := func(token Token) error { return nil }
  1148  
  1149  	spt, err := NewServicePrincipalTokenFromMSIWithIdentityResourceID("http://msiendpoint/", resource, identityResourceID, cb)
  1150  	if err != nil {
  1151  		t.Fatalf("Failed to get MSI SPT: %v", err)
  1152  	}
  1153  
  1154  	// check some of the SPT fields
  1155  	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
  1156  		t.Fatal("SPT secret was not of MSI type")
  1157  	}
  1158  
  1159  	if spt.inner.Resource != resource {
  1160  		t.Fatal("SPT came back with incorrect resource")
  1161  	}
  1162  
  1163  	if len(spt.refreshCallbacks) != 1 {
  1164  		t.Fatal("SPT had incorrect refresh callbacks.")
  1165  	}
  1166  
  1167  	urlPathParameter := url.Values{}
  1168  	urlPathParameter.Set("mi_res_id", identityResourceID)
  1169  
  1170  	if !strings.Contains(spt.inner.OauthConfig.TokenEndpoint.RawQuery, urlPathParameter.Encode()) {
  1171  		t.Fatal("SPT tokenEndpoint should contains mi_res_id")
  1172  	}
  1173  }
  1174  
  1175  func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) {
  1176  	token := newToken()
  1177  	secret := &ServicePrincipalAuthorizationCodeSecret{
  1178  		ClientSecret:      "clientSecret",
  1179  		AuthorizationCode: "code123",
  1180  		RedirectURI:       "redirect",
  1181  	}
  1182  
  1183  	spt, err := NewServicePrincipalTokenFromManualTokenSecret(TestOAuthConfig, "id", "resource", token, secret, nil)
  1184  	if err != nil {
  1185  		t.Fatalf("Failed creating new SPT: %s", err)
  1186  	}
  1187  
  1188  	if !reflect.DeepEqual(token, spt.inner.Token) {
  1189  		t.Fatalf("Tokens do not match: %s, %s", token, spt.inner.Token)
  1190  	}
  1191  
  1192  	if !reflect.DeepEqual(secret, spt.inner.Secret) {
  1193  		t.Fatalf("Secrets do not match: %s, %s", secret, spt.inner.Secret)
  1194  	}
  1195  
  1196  }
  1197  
  1198  func TestGetVMEndpoint(t *testing.T) {
  1199  	endpoint, err := GetMSIVMEndpoint()
  1200  	if err != nil {
  1201  		t.Fatal("Coudn't get VM endpoint")
  1202  	}
  1203  
  1204  	if endpoint != msiEndpoint {
  1205  		t.Fatal("Didn't get correct endpoint")
  1206  	}
  1207  }
  1208  
  1209  func TestGetAppServiceEndpoint(t *testing.T) {
  1210  	const testEndpoint = "http://172.16.1.2:8081/msi/token"
  1211  	const aseSecret = "the_secret"
  1212  	if err := os.Setenv(msiEndpointEnv, testEndpoint); err != nil {
  1213  		t.Fatalf("os.Setenv: %v", err)
  1214  	}
  1215  	if err := os.Setenv(msiSecretEnv, aseSecret); err != nil {
  1216  		t.Fatalf("os.Setenv: %v", err)
  1217  	}
  1218  	defer func() {
  1219  		os.Unsetenv(msiEndpointEnv)
  1220  		os.Unsetenv(msiSecretEnv)
  1221  	}()
  1222  
  1223  	endpoint, err := GetMSIAppServiceEndpoint()
  1224  	if err != nil {
  1225  		t.Fatal("Coudn't get App Service endpoint")
  1226  	}
  1227  
  1228  	if endpoint != testEndpoint {
  1229  		t.Fatal("Didn't get correct endpoint")
  1230  	}
  1231  }
  1232  
  1233  func TestGetMSIEndpoint(t *testing.T) {
  1234  	const (
  1235  		testEndpoint = "http://172.16.1.2:8081/msi/token"
  1236  		testSecret   = "DEADBEEF-BBBB-AAAA-DDDD-DDD000000DDD"
  1237  	)
  1238  
  1239  	// Test VM well-known endpoint is returned
  1240  	if err := os.Unsetenv(msiEndpointEnv); err != nil {
  1241  		t.Fatalf("os.Unsetenv: %v", err)
  1242  	}
  1243  
  1244  	if err := os.Unsetenv(msiSecretEnv); err != nil {
  1245  		t.Fatalf("os.Unsetenv: %v", err)
  1246  	}
  1247  
  1248  	vmEndpoint, err := GetMSIEndpoint()
  1249  	if err != nil {
  1250  		t.Fatal("Coudn't get VM endpoint")
  1251  	}
  1252  
  1253  	if vmEndpoint != msiEndpoint {
  1254  		t.Fatal("Didn't get correct endpoint")
  1255  	}
  1256  
  1257  	// Test App Service endpoint is returned
  1258  	if err := os.Setenv(msiEndpointEnv, testEndpoint); err != nil {
  1259  		t.Fatalf("os.Setenv: %v", err)
  1260  	}
  1261  
  1262  	if err := os.Setenv(msiSecretEnv, testSecret); err != nil {
  1263  		t.Fatalf("os.Setenv: %v", err)
  1264  	}
  1265  
  1266  	asEndpoint, err := GetMSIEndpoint()
  1267  	if err != nil {
  1268  		t.Fatal("Coudn't get App Service endpoint")
  1269  	}
  1270  
  1271  	if asEndpoint != testEndpoint {
  1272  		t.Fatal("Didn't get correct endpoint")
  1273  	}
  1274  
  1275  	if err := os.Unsetenv(msiEndpointEnv); err != nil {
  1276  		t.Fatalf("os.Unsetenv: %v", err)
  1277  	}
  1278  
  1279  	if err := os.Unsetenv(msiSecretEnv); err != nil {
  1280  		t.Fatalf("os.Unsetenv: %v", err)
  1281  	}
  1282  }
  1283  
  1284  func TestClientSecretWithASESet(t *testing.T) {
  1285  	if err := os.Setenv(msiEndpointEnv, "http://172.16.1.2:8081/msi/token"); err != nil {
  1286  		t.Fatalf("os.Setenv: %v", err)
  1287  	}
  1288  	if err := os.Setenv(msiSecretEnv, "the_secret"); err != nil {
  1289  		t.Fatalf("os.Setenv: %v", err)
  1290  	}
  1291  	defer func() {
  1292  		os.Unsetenv(msiEndpointEnv)
  1293  		os.Unsetenv(msiSecretEnv)
  1294  	}()
  1295  	spt := newServicePrincipalToken()
  1296  	if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
  1297  		t.Fatal("should not have MSI secret for client secret token even when ASE is enabled")
  1298  	}
  1299  }
  1300  
  1301  func TestMarshalServicePrincipalNoSecret(t *testing.T) {
  1302  	spt := newServicePrincipalTokenManual()
  1303  	b, err := json.Marshal(spt)
  1304  	if err != nil {
  1305  		t.Fatalf("failed to marshal token: %+v", err)
  1306  	}
  1307  	var spt2 *ServicePrincipalToken
  1308  	err = json.Unmarshal(b, &spt2)
  1309  	if err != nil {
  1310  		t.Fatalf("failed to unmarshal token: %+v", err)
  1311  	}
  1312  	if !reflect.DeepEqual(spt, spt2) {
  1313  		t.Fatal("tokens don't match")
  1314  	}
  1315  }
  1316  
  1317  func TestMarshalServicePrincipalTokenSecret(t *testing.T) {
  1318  	spt := newServicePrincipalToken()
  1319  	b, err := json.Marshal(spt)
  1320  	if err != nil {
  1321  		t.Fatalf("failed to marshal token: %+v", err)
  1322  	}
  1323  	var spt2 *ServicePrincipalToken
  1324  	err = json.Unmarshal(b, &spt2)
  1325  	if err != nil {
  1326  		t.Fatalf("failed to unmarshal token: %+v", err)
  1327  	}
  1328  	if !reflect.DeepEqual(spt, spt2) {
  1329  		t.Fatal("tokens don't match")
  1330  	}
  1331  }
  1332  
  1333  func TestMarshalServicePrincipalCertificateSecret(t *testing.T) {
  1334  	spt := newServicePrincipalTokenCertificate(t)
  1335  	b, err := json.Marshal(spt)
  1336  	if err == nil {
  1337  		t.Fatal("expected error when marshalling certificate token")
  1338  	}
  1339  	var spt2 *ServicePrincipalToken
  1340  	err = json.Unmarshal(b, &spt2)
  1341  	if err == nil {
  1342  		t.Fatal("expected error when unmarshalling certificate token")
  1343  	}
  1344  }
  1345  
  1346  func TestMarshalServicePrincipalMSISecret(t *testing.T) {
  1347  	spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", "", "")
  1348  	if err != nil {
  1349  		t.Fatalf("failed to get MSI SPT: %+v", err)
  1350  	}
  1351  	b, err := json.Marshal(spt)
  1352  	if err == nil {
  1353  		t.Fatal("expected error when marshalling MSI token")
  1354  	}
  1355  	var spt2 *ServicePrincipalToken
  1356  	err = json.Unmarshal(b, &spt2)
  1357  	if err == nil {
  1358  		t.Fatal("expected error when unmarshalling MSI token")
  1359  	}
  1360  }
  1361  
  1362  func TestMarshalServicePrincipalUsernamePasswordSecret(t *testing.T) {
  1363  	spt := newServicePrincipalTokenUsernamePassword(t)
  1364  	b, err := json.Marshal(spt)
  1365  	if err != nil {
  1366  		t.Fatalf("failed to marshal token: %+v", err)
  1367  	}
  1368  	var spt2 *ServicePrincipalToken
  1369  	err = json.Unmarshal(b, &spt2)
  1370  	if err != nil {
  1371  		t.Fatalf("failed to unmarshal token: %+v", err)
  1372  	}
  1373  	if !reflect.DeepEqual(spt, spt2) {
  1374  		t.Fatal("tokens don't match")
  1375  	}
  1376  }
  1377  
  1378  func TestMarshalServicePrincipalAuthorizationCodeSecret(t *testing.T) {
  1379  	spt := newServicePrincipalTokenAuthorizationCode(t)
  1380  	b, err := json.Marshal(spt)
  1381  	if err != nil {
  1382  		t.Fatalf("failed to marshal token: %+v", err)
  1383  	}
  1384  	var spt2 *ServicePrincipalToken
  1385  	err = json.Unmarshal(b, &spt2)
  1386  	if err != nil {
  1387  		t.Fatalf("failed to unmarshal token: %+v", err)
  1388  	}
  1389  	if !reflect.DeepEqual(spt, spt2) {
  1390  		t.Fatal("tokens don't match")
  1391  	}
  1392  }
  1393  
  1394  func TestMarshalServicePrincipalFederatedSecret(t *testing.T) {
  1395  	spt := newServicePrincipalTokenFederatedJwt(t)
  1396  	b, err := json.Marshal(spt)
  1397  	if err == nil {
  1398  		t.Fatal("expected error when marshalling certificate token")
  1399  	}
  1400  	var spt2 *ServicePrincipalToken
  1401  	err = json.Unmarshal(b, &spt2)
  1402  	if err == nil {
  1403  		t.Fatal("expected error when unmarshalling certificate token")
  1404  	}
  1405  }
  1406  
  1407  func TestMarshalInnerToken(t *testing.T) {
  1408  	spt := newServicePrincipalTokenManual()
  1409  	tokenJSON, err := spt.MarshalTokenJSON()
  1410  	if err != nil {
  1411  		t.Fatalf("failed to marshal token: %+v", err)
  1412  	}
  1413  
  1414  	testToken := newToken()
  1415  	testToken.RefreshToken = "refreshtoken"
  1416  
  1417  	testTokenJSON, err := json.Marshal(testToken)
  1418  	if err != nil {
  1419  		t.Fatalf("failed to marshal test token: %+v", err)
  1420  	}
  1421  
  1422  	if !reflect.DeepEqual(tokenJSON, testTokenJSON) {
  1423  		t.Fatalf("tokens don't match: %s, %s", tokenJSON, testTokenJSON)
  1424  	}
  1425  
  1426  	var t1 Token
  1427  	err = json.Unmarshal(tokenJSON, &t1)
  1428  	if err != nil {
  1429  		t.Fatalf("failed to unmarshal token: %+v", err)
  1430  	}
  1431  
  1432  	if !reflect.DeepEqual(t1, testToken) {
  1433  		t.Fatalf("tokens don't match: %s, %s", t1, testToken)
  1434  	}
  1435  }
  1436  
  1437  func TestNewMultiTenantServicePrincipalToken(t *testing.T) {
  1438  	cfg, err := NewMultiTenantOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID, TestAuxTenantIDs, OAuthOptions{})
  1439  	if err != nil {
  1440  		t.Fatalf("autorest/adal: unexpected error while creating multitenant config: %v", err)
  1441  	}
  1442  	mt, err := NewMultiTenantServicePrincipalToken(cfg, "clientID", "superSecret", "resource")
  1443  	if err != nil {
  1444  		t.Fatalf("autorest/adal: unexpected error while creating multitenant service principal token: %v", err)
  1445  	}
  1446  	if !strings.Contains(mt.PrimaryToken.inner.OauthConfig.AuthorizeEndpoint.String(), TestTenantID) {
  1447  		t.Fatal("didn't find primary tenant ID in primary SPT")
  1448  	}
  1449  	for i := range mt.AuxiliaryTokens {
  1450  		if ep := mt.AuxiliaryTokens[i].inner.OauthConfig.AuthorizeEndpoint.String(); !strings.Contains(ep, fmt.Sprintf("%s%d", TestAuxTenantPrefix, i)) {
  1451  			t.Fatalf("didn't find auxiliary tenant ID in token %s", ep)
  1452  		}
  1453  	}
  1454  }
  1455  
  1456  func TestNewMultiTenantServicePrincipalTokenFromCertificate(t *testing.T) {
  1457  	cfg, err := NewMultiTenantOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID, TestAuxTenantIDs, OAuthOptions{})
  1458  	if err != nil {
  1459  		t.Fatalf("autorest/adal: unexpected error while creating multitenant config: %v", err)
  1460  	}
  1461  	cert, key := newTestCertificate(t)
  1462  	mt, err := NewMultiTenantServicePrincipalTokenFromCertificate(cfg, "clientID", cert, key, "resource")
  1463  	if err != nil {
  1464  		t.Fatalf("autorest/adal: unexpected error while creating multitenant service principal token: %v", err)
  1465  	}
  1466  	if !strings.Contains(mt.PrimaryToken.inner.OauthConfig.AuthorizeEndpoint.String(), TestTenantID) {
  1467  		t.Fatal("didn't find primary tenant ID in primary SPT")
  1468  	}
  1469  	for i := range mt.AuxiliaryTokens {
  1470  		if ep := mt.AuxiliaryTokens[i].inner.OauthConfig.AuthorizeEndpoint.String(); !strings.Contains(ep, fmt.Sprintf("%s%d", TestAuxTenantPrefix, i)) {
  1471  			t.Fatalf("didn't find auxiliary tenant ID in token %s", ep)
  1472  		}
  1473  	}
  1474  }
  1475  
  1476  func TestMSIAvailableSuccess(t *testing.T) {
  1477  	c := mocks.NewSender()
  1478  	c.AppendResponse(mocks.NewResponse())
  1479  	if !MSIAvailable(context.Background(), c) {
  1480  		t.Fatal("unexpected false")
  1481  	}
  1482  }
  1483  
  1484  func TestMSIAvailableAppService(t *testing.T) {
  1485  	os.Setenv("MSI_ENDPOINT", "http://localhost")
  1486  	os.Setenv("MSI_SECRET", "super")
  1487  	defer func() {
  1488  		os.Unsetenv("MSI_ENDPOINT")
  1489  		os.Unsetenv("MSI_SECRET")
  1490  	}()
  1491  	c := mocks.NewSender()
  1492  	c.AppendResponse(mocks.NewResponse())
  1493  	available := MSIAvailable(context.Background(), c)
  1494  
  1495  	if !available {
  1496  		t.Fatal("expected MSI to be available")
  1497  	}
  1498  }
  1499  
  1500  func TestMSIAvailableIMDS(t *testing.T) {
  1501  	c := mocks.NewSender()
  1502  	c.AppendResponse(mocks.NewResponse())
  1503  	available := MSIAvailable(context.Background(), c)
  1504  
  1505  	if !available {
  1506  		t.Fatal("expected MSI to be available")
  1507  	}
  1508  }
  1509  
  1510  func TestMSIAvailableSlow(t *testing.T) {
  1511  	c := mocks.NewSender()
  1512  	// introduce a long response delay to simulate the endpoint not being available
  1513  	c.AppendResponseWithDelay(mocks.NewResponse(), 5*time.Second)
  1514  	if MSIAvailable(context.Background(), c) {
  1515  		t.Fatal("unexpected true")
  1516  	}
  1517  }
  1518  
  1519  func TestMSIAvailableFail(t *testing.T) {
  1520  	expectErr := "failed to make msi http request"
  1521  	c := mocks.NewSender()
  1522  	c.AppendAndRepeatError(fmt.Errorf(expectErr), 2)
  1523  	if MSIAvailable(context.Background(), c) {
  1524  		t.Fatal("unexpected true")
  1525  	}
  1526  	_, err := getMSIEndpoint(context.Background(), c)
  1527  	if !strings.Contains(err.Error(), "") {
  1528  		t.Fatalf("expected error: '%s', but got error '%s'", expectErr, err)
  1529  	}
  1530  }
  1531  
  1532  func newTokenJSON(expiresIn, expiresOn, resource string) string {
  1533  	nb, err := parseExpiresOn(expiresOn)
  1534  	if err != nil {
  1535  		panic(err)
  1536  	}
  1537  	return fmt.Sprintf(`{
  1538  		"access_token" : "accessToken",
  1539  		"expires_in"   : %s,
  1540  		"expires_on"   : "%s",
  1541  		"not_before"   : "%s",
  1542  		"resource"     : "%s",
  1543  		"token_type"   : "Bearer",
  1544  		"refresh_token": "ABC123"
  1545  		}`,
  1546  		expiresIn, expiresOn, nb, resource)
  1547  }
  1548  
  1549  func newTokenJSONIntExpiresOn(expiresIn string, expiresOn int, resource string) string {
  1550  	return fmt.Sprintf(`{
  1551  		"access_token" : "accessToken",
  1552  		"expires_in"   : %s,
  1553  		"expires_on"   : %d,
  1554  		"not_before"   : "%d",
  1555  		"resource"     : "%s",
  1556  		"token_type"   : "Bearer",
  1557  		"refresh_token": "ABC123"
  1558  		}`,
  1559  		expiresIn, expiresOn, expiresOn, resource)
  1560  }
  1561  
  1562  func newADFSTokenJSON(expiresIn int) string {
  1563  	return fmt.Sprintf(`{
  1564  		"access_token" : "accessToken",
  1565  		"expires_in"   : %d,
  1566  		"token_type"   : "Bearer"
  1567  		}`,
  1568  		expiresIn)
  1569  }
  1570  
  1571  func newTokenExpiresIn(expireIn time.Duration) *Token {
  1572  	t := newToken()
  1573  	return setTokenToExpireIn(&t, expireIn)
  1574  }
  1575  
  1576  func newTokenExpiresAt(expireAt time.Time) *Token {
  1577  	t := newToken()
  1578  	return setTokenToExpireAt(&t, expireAt)
  1579  }
  1580  
  1581  func expireToken(t *Token) *Token {
  1582  	return setTokenToExpireIn(t, 0)
  1583  }
  1584  
  1585  func setTokenToExpireAt(t *Token, expireAt time.Time) *Token {
  1586  	t.ExpiresIn = "3600"
  1587  	t.ExpiresOn = json.Number(strconv.FormatInt(int64(expireAt.Sub(date.UnixEpoch())/time.Second), 10))
  1588  	t.NotBefore = t.ExpiresOn
  1589  	return t
  1590  }
  1591  
  1592  func setTokenToExpireIn(t *Token, expireIn time.Duration) *Token {
  1593  	return setTokenToExpireAt(t, time.Now().Add(expireIn))
  1594  }
  1595  
  1596  func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincipalToken {
  1597  	spt, _ := NewServicePrincipalToken(TestOAuthConfig, "id", "secret", "resource", callbacks...)
  1598  	return spt
  1599  }
  1600  
  1601  func newServicePrincipalTokenManual() *ServicePrincipalToken {
  1602  	token := newToken()
  1603  	token.RefreshToken = "refreshtoken"
  1604  	spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", token)
  1605  	return spt
  1606  }
  1607  
  1608  func newTestCertificate(t *testing.T) (*x509.Certificate, *rsa.PrivateKey) {
  1609  	template := x509.Certificate{
  1610  		SerialNumber:          big.NewInt(0),
  1611  		Subject:               pkix.Name{CommonName: "test"},
  1612  		BasicConstraintsValid: true,
  1613  	}
  1614  	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
  1615  	if err != nil {
  1616  		t.Fatal(err)
  1617  	}
  1618  	certificateBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
  1619  	if err != nil {
  1620  		t.Fatal(err)
  1621  	}
  1622  	certificate, err := x509.ParseCertificate(certificateBytes)
  1623  	if err != nil {
  1624  		t.Fatal(err)
  1625  	}
  1626  	return certificate, privateKey
  1627  }
  1628  
  1629  func newServicePrincipalTokenCertificate(t *testing.T) *ServicePrincipalToken {
  1630  	certificate, privateKey := newTestCertificate(t)
  1631  
  1632  	spt, _ := NewServicePrincipalTokenFromCertificate(TestOAuthConfig, "id", certificate, privateKey, "resource")
  1633  	return spt
  1634  }
  1635  
  1636  func newServicePrincipalTokenUsernamePassword(t *testing.T) *ServicePrincipalToken {
  1637  	spt, _ := NewServicePrincipalTokenFromUsernamePassword(TestOAuthConfig, "id", "username", "password", "resource")
  1638  	return spt
  1639  }
  1640  
  1641  func newServicePrincipalTokenAuthorizationCode(t *testing.T) *ServicePrincipalToken {
  1642  	spt, _ := NewServicePrincipalTokenFromAuthorizationCode(TestOAuthConfig, "id", "clientSecret", "code", "http://redirectUri/getToken", "resource")
  1643  	return spt
  1644  }
  1645  
  1646  func newServicePrincipalTokenFederatedJwt(t *testing.T) *ServicePrincipalToken {
  1647  	token := jwt.New(jwt.SigningMethodHS256)
  1648  	token.Header["typ"] = "JWT"
  1649  	token.Claims = jwt.MapClaims{
  1650  		"aud": "testAudience",
  1651  		"iss": "id",
  1652  		"sub": "id",
  1653  		"nbf": time.Now().Unix(),
  1654  		"exp": time.Now().Add(24 * time.Hour).Unix(),
  1655  	}
  1656  
  1657  	signedString, err := token.SignedString([]byte("test key"))
  1658  	if err != nil {
  1659  		t.Fatal(err)
  1660  	}
  1661  	spt, _ := NewServicePrincipalTokenFromFederatedToken(TestOAuthConfig, "id", signedString, "resource")
  1662  	return spt
  1663  }
  1664  
  1665  func newServicePrincipalTokenFederatedJwtCallback(t *testing.T, callback JWTCallback, fakeEndpoint string) *ServicePrincipalToken {
  1666  	outhConfig, _ := NewOAuthConfig(fakeEndpoint, TestTenantID)
  1667  	spt, _ := NewServicePrincipalTokenFromFederatedTokenCallback(*outhConfig, "id", callback, "resource")
  1668  	return spt
  1669  }
  1670  

View as plain text