1 package adal
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 import (
18 "context"
19 "crypto/rand"
20 "crypto/rsa"
21 "crypto/x509"
22 "crypto/x509/pkix"
23 "encoding/json"
24 "fmt"
25 "io/ioutil"
26 "math/big"
27 "net/http"
28 "net/http/httptest"
29 "net/url"
30 "os"
31 "path/filepath"
32 "reflect"
33 "strconv"
34 "strings"
35 "sync"
36 "testing"
37 "time"
38
39 "github.com/Azure/go-autorest/autorest/date"
40 "github.com/Azure/go-autorest/autorest/mocks"
41 jwt "github.com/golang-jwt/jwt/v4"
42 "github.com/stretchr/testify/assert"
43 )
44
const (
	// defaultFormData is the exact url-encoded request body expected when the
	// token built by newServicePrincipalToken (client-secret credentials) is refreshed.
	defaultFormData = "client_id=id&client_secret=secret&grant_type=client_credentials&resource=resource"
	// defaultManualFormData is the exact url-encoded request body expected when the
	// token built by newServicePrincipalTokenManual (refresh-token grant) is refreshed.
	defaultManualFormData = "client_id=id&grant_type=refresh_token&refresh_token=refreshtoken&resource=resource"
)
49
50 func TestTokenExpires(t *testing.T) {
51 tt := time.Now().Add(5 * time.Second)
52 tk := newTokenExpiresAt(tt)
53
54 if tk.Expires().Equal(tt) {
55 t.Fatalf("adal: Token#Expires miscalculated expiration time -- received %v, expected %v", tk.Expires(), tt)
56 }
57 }
58
59 func TestTokenIsExpired(t *testing.T) {
60 tk := newTokenExpiresAt(time.Now().Add(-5 * time.Second))
61
62 if !tk.IsExpired() {
63 t.Fatalf("adal: Token#IsExpired failed to mark a stale token as expired -- now %v, token expires at %v",
64 time.Now().UTC(), tk.Expires())
65 }
66 }
67
68 func TestTokenIsExpiredUninitialized(t *testing.T) {
69 tk := &Token{}
70
71 if !tk.IsExpired() {
72 t.Fatalf("adal: An uninitialized Token failed to mark itself as expired (expiration time %v)", tk.Expires())
73 }
74 }
75
76 func TestTokenIsNoExpired(t *testing.T) {
77 tk := newTokenExpiresAt(time.Now().Add(1000 * time.Second))
78
79 if tk.IsExpired() {
80 t.Fatalf("adal: Token marked a fresh token as expired -- now %v, token expires at %v", time.Now().UTC(), tk.Expires())
81 }
82 }
83
84 func TestTokenWillExpireIn(t *testing.T) {
85 d := 5 * time.Second
86 tk := newTokenExpiresIn(d)
87
88 if !tk.WillExpireIn(d) {
89 t.Fatal("adal: Token#WillExpireIn mismeasured expiration time")
90 }
91 }
92
93 func TestParseExpiresOn(t *testing.T) {
94 n := time.Now().UTC()
95 amPM := "AM"
96 if n.Hour() >= 12 {
97 amPM = "PM"
98 }
99 testcases := []struct {
100 Name string
101 String string
102 Value int64
103 }{
104 {
105 Name: "integer",
106 String: "3600",
107 Value: 3600,
108 },
109 {
110 Name: "timestamp with AM/PM",
111 String: fmt.Sprintf("%d/%d/%d %d:%02d:%02d %s +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second(), amPM),
112 Value: n.Unix(),
113 },
114 {
115 Name: "timestamp without AM/PM",
116 String: fmt.Sprintf("%02d/%02d/%02d %02d:%02d:%02d +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second()),
117 Value: n.Unix(),
118 },
119 }
120 for _, tc := range testcases {
121 t.Run(tc.Name, func(subT *testing.T) {
122 jn, err := parseExpiresOn(tc.String)
123 if err != nil {
124 subT.Error(err)
125 }
126 i, err := jn.Int64()
127 if err != nil {
128 subT.Error(err)
129 }
130 if i != tc.Value {
131 subT.Logf("expected %d, got %d", tc.Value, i)
132 subT.Fail()
133 }
134 })
135 }
136 }
137
138 func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
139 spt := newServicePrincipalToken()
140
141 if !spt.inner.AutoRefresh {
142 t.Fatal("adal: ServicePrincipalToken did not default to automatic token refreshing")
143 }
144
145 spt.SetAutoRefresh(false)
146 if spt.inner.AutoRefresh {
147 t.Fatal("adal: ServicePrincipalToken#SetAutoRefresh did not disable automatic token refreshing")
148 }
149 }
150
151 func TestServicePrincipalTokenSetCustomRefreshFunc(t *testing.T) {
152 spt := newServicePrincipalToken()
153
154 var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
155 return nil, nil
156 }
157
158 if spt.customRefreshFunc != nil {
159 t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc had a default custom refresh func when it shouldn't")
160 }
161
162 spt.SetCustomRefreshFunc(refreshFunc)
163
164 if spt.customRefreshFunc == nil {
165 t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc didn't have a refresh func")
166 }
167 }
168
169 func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
170 spt := newServicePrincipalToken()
171
172 if spt.inner.RefreshWithin != defaultRefresh {
173 t.Fatal("adal: ServicePrincipalToken did not correctly set the default refresh interval")
174 }
175
176 spt.SetRefreshWithin(2 * defaultRefresh)
177 if spt.inner.RefreshWithin != 2*defaultRefresh {
178 t.Fatal("adal: ServicePrincipalToken#SetRefreshWithin did not set the refresh interval")
179 }
180 }
181
182 func TestServicePrincipalTokenSetSender(t *testing.T) {
183 spt := newServicePrincipalToken()
184
185 c := &http.Client{}
186 spt.SetSender(c)
187 if !reflect.DeepEqual(c, spt.sender) {
188 t.Fatal("adal: ServicePrincipalToken#SetSender did not set the sender")
189 }
190 }
191
192 func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) {
193 spt := newServicePrincipalToken()
194
195 called := false
196 var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
197 called = true
198 return &Token{}, nil
199 }
200 spt.SetCustomRefreshFunc(refreshFunc)
201 if called {
202 t.Fatalf("adal: ServicePrincipalToken#refreshInternal called the refresh function prior to refreshing")
203 }
204
205 spt.refreshInternal(context.Background(), "https://example.com")
206
207 if !called {
208 t.Fatalf("adal: ServicePrincipalToken#refreshInternal didn't call the refresh function")
209 }
210 }
211
212 func TestFederatedTokenRefreshUsesJwtCallback(t *testing.T) {
213 baseDir, err := os.MkdirTemp("", "")
214 assert.NoError(t, err)
215 jwtFile := filepath.Join(baseDir, "token")
216
217 jwtCallback := func() (string, error) {
218 jwt, err := os.ReadFile(jwtFile)
219 if err != nil {
220 return "", fmt.Errorf("failed to read a file with a federated token: %w", err)
221 }
222 return string(jwt), nil
223 }
224
225 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
226 jwt := r.FormValue("client_assertion")
227 refreshToken := r.FormValue("refresh_token")
228
229 if jwt == "aaa.aaa" {
230 w.Write([]byte(`{"access_token":"A","expires_in":"3600"}`))
231 } else if jwt == "bbb.bbb" {
232 w.Write([]byte(`{"access_token":"B","expires_in":"3600","refresh_token":"R"}`))
233 } else if refreshToken == "R" {
234 w.Write([]byte(`{"access_token":"C","expires_in":"3600"}`))
235 } else {
236 w.WriteHeader(http.StatusBadRequest)
237 }
238 }))
239
240 spt := newServicePrincipalTokenFederatedJwtCallback(t, jwtCallback, server.URL)
241
242
243 err = spt.refreshInternal(context.Background(), "")
244 assert.ErrorIs(t, err, os.ErrNotExist)
245
246
247 err = os.WriteFile(jwtFile, []byte("aaa.aaa"), 0600)
248 assert.NoError(t, err)
249 err = spt.refreshInternal(context.Background(), "")
250 assert.NoError(t, err)
251 assert.Equal(t, "A", spt.inner.Token.AccessToken)
252
253
254 err = os.WriteFile(jwtFile, []byte("bbb.bbb"), 0600)
255 assert.NoError(t, err)
256 err = spt.refreshInternal(context.Background(), "")
257 assert.NoError(t, err)
258 assert.Equal(t, "B", spt.inner.Token.AccessToken)
259
260 assert.Equal(t, "R", spt.inner.Token.RefreshToken)
261
262
263 err = spt.refreshInternal(context.Background(), "")
264 assert.NoError(t, err)
265 assert.Equal(t, "C", spt.inner.Token.AccessToken)
266 }
267
268 func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
269 spt := newServicePrincipalToken()
270
271 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
272 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
273
274 c := mocks.NewSender()
275 s := DecorateSender(c,
276 (func() SendDecorator {
277 return func(s Sender) Sender {
278 return SenderFunc(func(r *http.Request) (*http.Response, error) {
279 if r.Method != "POST" {
280 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "POST", r.Method)
281 }
282 return resp, nil
283 })
284 }
285 })())
286 spt.SetSender(s)
287 err := spt.Refresh()
288 if err != nil {
289 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
290 }
291
292 if body.IsOpen() {
293 t.Fatalf("the response was not closed!")
294 }
295 }
296
297 func TestNewServicePrincipalTokenFromManagedIdentity(t *testing.T) {
298 spt, err := NewServicePrincipalTokenFromManagedIdentity("https://resource", nil)
299 if err != nil {
300 t.Fatalf("Failed to get MSI SPT: %v", err)
301 }
302
303 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
304 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
305
306 c := mocks.NewSender()
307 s := DecorateSender(c,
308 (func() SendDecorator {
309 return func(s Sender) Sender {
310 return SenderFunc(func(r *http.Request) (*http.Response, error) {
311 if r.Method != "GET" {
312 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
313 }
314 if h := r.Header.Get("Metadata"); h != "true" {
315 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
316 }
317 return resp, nil
318 })
319 }
320 })())
321 spt.SetSender(s)
322 err = spt.Refresh()
323 if err != nil {
324 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
325 }
326
327 if body.IsOpen() {
328 t.Fatalf("the response was not closed!")
329 }
330 }
331
332 func TestServicePrincipalTokenFromMSICloudshell(t *testing.T) {
333 os.Setenv(msiEndpointEnv, "http://dummy")
334 defer func() {
335 os.Unsetenv(msiEndpointEnv)
336 }()
337 spt, err := NewServicePrincipalTokenFromMSI("", "https://resource")
338 if err != nil {
339 t.Fatalf("Failed to get MSI SPT: %v", err)
340 }
341
342 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
343 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
344
345 c := mocks.NewSender()
346 s := DecorateSender(c,
347 (func() SendDecorator {
348 return func(s Sender) Sender {
349 return SenderFunc(func(r *http.Request) (*http.Response, error) {
350 if r.Method != http.MethodPost {
351 t.Fatalf("adal: cloudshell did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
352 }
353 if h := r.Header.Get("Metadata"); h != "true" {
354 t.Fatalf("adal: cloudshell did not correctly set Metadata header")
355 }
356 if h := r.Header.Get("Content-Type"); h != "application/x-www-form-urlencoded" {
357 t.Fatalf("adal: cloudshell did not correctly set Content-Type header")
358 }
359 return resp, nil
360 })
361 }
362 })())
363 spt.SetSender(s)
364 err = spt.Refresh()
365 if err != nil {
366 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
367 }
368
369 if body.IsOpen() {
370 t.Fatalf("the response was not closed!")
371 }
372 }
373
374 func TestServicePrincipalTokenFromMSIRefreshZeroRetry(t *testing.T) {
375 resource := "https://resource"
376 cb := func(token Token) error { return nil }
377
378 endpoint, _ := GetMSIVMEndpoint()
379 spt, err := NewServicePrincipalTokenFromMSI(endpoint, resource, cb)
380 if err != nil {
381 t.Fatalf("Failed to get MSI SPT: %v", err)
382 }
383 spt.MaxMSIRefreshAttempts = 1
384
385 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
386 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
387
388 c := mocks.NewSender()
389 s := DecorateSender(c,
390 (func() SendDecorator {
391 return func(s Sender) Sender {
392 return SenderFunc(func(r *http.Request) (*http.Response, error) {
393
394 if r.Method != "GET" {
395 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
396 }
397 if h := r.Header.Get("Metadata"); h != "true" {
398 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Metadata header for MSI")
399 }
400 return resp, nil
401 })
402 }
403 })())
404 spt.SetSender(s)
405 err = spt.Refresh()
406 if err != nil {
407 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
408 }
409
410 if body.IsOpen() {
411 t.Fatalf("the response was not closed!")
412 }
413 }
414
415 func TestServicePrincipalTokenFromASE(t *testing.T) {
416 os.Setenv("MSI_ENDPOINT", "http://localhost")
417 os.Setenv("MSI_SECRET", "super")
418 defer func() {
419 os.Unsetenv("MSI_ENDPOINT")
420 os.Unsetenv("MSI_SECRET")
421 }()
422 resource := "https://resource"
423 spt, err := NewServicePrincipalTokenFromMSI("", resource)
424 if err != nil {
425 t.Fatalf("Failed to get MSI SPT: %v", err)
426 }
427 spt.MaxMSIRefreshAttempts = 1
428
429 nowTime := time.Now()
430 expiresOn := nowTime.UTC().Add(time.Hour)
431
432 body := mocks.NewBody(newTokenJSON("3600", expiresOn.Format(expiresOnDateFormat), "test"))
433 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
434
435 c := mocks.NewSender()
436 s := DecorateSender(c,
437 (func() SendDecorator {
438 return func(s Sender) Sender {
439 return SenderFunc(func(r *http.Request) (*http.Response, error) {
440 if r.Method != "GET" {
441 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
442 }
443 if h := r.Header.Get(metadataHeader); h != "" {
444 t.Fatalf("adal: ServicePrincipalToken#Refresh incorrectly set Metadata header for ASE")
445 }
446 if s := r.Header.Get(secretHeader); s != "super" {
447 t.Fatalf("adal: unexpected secret header value %s", s)
448 }
449 if r.URL.Host != "localhost" {
450 t.Fatalf("adal: unexpected host %s", r.URL.Host)
451 }
452 qp := r.URL.Query()
453 if api := qp.Get("api-version"); api != appServiceAPIVersion2017 {
454 t.Fatalf("adal: unexpected api-version %s", api)
455 }
456 return resp, nil
457 })
458 }
459 })())
460 spt.SetSender(s)
461 err = spt.Refresh()
462 if err != nil {
463 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
464 }
465 v, err := spt.inner.Token.ExpiresOn.Int64()
466 if err != nil {
467 t.Fatalf("adal: failed to get ExpiresOn %v", err)
468 }
469 if nowAsUnix := nowTime.Add(time.Hour).Unix(); v != nowAsUnix {
470 t.Fatalf("adal: expected %v, got %v", nowAsUnix, v)
471 }
472 if body.IsOpen() {
473 t.Fatalf("the response was not closed!")
474 }
475 }
476
477 func TestServicePrincipalTokenFromADFS(t *testing.T) {
478 os.Setenv("MSI_ENDPOINT", "http://localhost")
479 os.Setenv("MSI_SECRET", "super")
480 defer func() {
481 os.Unsetenv("MSI_ENDPOINT")
482 os.Unsetenv("MSI_SECRET")
483 }()
484 resource := "https://resource"
485 endpoint, _ := GetMSIEndpoint()
486 spt, err := NewServicePrincipalTokenFromMSI(endpoint, resource)
487 if err != nil {
488 t.Fatalf("Failed to get MSI SPT: %v", err)
489 }
490 spt.MaxMSIRefreshAttempts = 1
491 const expiresIn = 3600
492 body := mocks.NewBody(newADFSTokenJSON(expiresIn))
493 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
494
495 c := mocks.NewSender()
496 s := DecorateSender(c,
497 (func() SendDecorator {
498 return func(s Sender) Sender {
499 return SenderFunc(func(r *http.Request) (*http.Response, error) {
500 if r.Method != "GET" {
501 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
502 }
503 if h := r.Header.Get(metadataHeader); h != "" {
504 t.Fatalf("adal: ServicePrincipalToken#Refresh incorrectly set Metadata header for ASE")
505 }
506 if s := r.Header.Get(secretHeader); s != "super" {
507 t.Fatalf("adal: unexpected secret header value %s", s)
508 }
509 if r.URL.Host != "localhost" {
510 t.Fatalf("adal: unexpected host %s", r.URL.Host)
511 }
512 qp := r.URL.Query()
513 if api := qp.Get("api-version"); api != appServiceAPIVersion2017 {
514 t.Fatalf("adal: unexpected api-version %s", api)
515 }
516 return resp, nil
517 })
518 }
519 })())
520 spt.SetSender(s)
521 err = spt.Refresh()
522 if err != nil {
523 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
524 }
525 i, err := spt.inner.Token.ExpiresIn.Int64()
526 if err != nil {
527 t.Fatalf("unexpected parsing of expires_in: %v", err)
528 }
529 if i != expiresIn {
530 t.Fatalf("unexpected expires_in %d", i)
531 }
532 if spt.inner.Token.ExpiresOn.String() != "" {
533 t.Fatal("expected empty expires_on")
534 }
535 if body.IsOpen() {
536 t.Fatalf("the response was not closed!")
537 }
538 }
539
540 func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) {
541 ctx, cancel := context.WithCancel(context.Background())
542 endpoint, _ := GetMSIVMEndpoint()
543
544 spt, err := NewServicePrincipalTokenFromMSI(endpoint, "https://resource")
545 if err != nil {
546 t.Fatalf("Failed to get MSI SPT: %v", err)
547 }
548
549 c := mocks.NewSender()
550 c.AppendAndRepeatResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError), 5)
551
552 var wg sync.WaitGroup
553 wg.Add(1)
554 start := time.Now()
555 end := time.Now()
556
557 go func() {
558 spt.SetSender(c)
559 err = spt.RefreshWithContext(ctx)
560 end = time.Now()
561 wg.Done()
562 }()
563
564 cancel()
565 wg.Wait()
566 time.Sleep(5 * time.Millisecond)
567
568 if end.Sub(start) >= time.Second {
569 t.Fatalf("TestServicePrincipalTokenFromMSIRefreshCancel failed to cancel")
570 }
571 }
572
573 func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) {
574 spt := newServicePrincipalToken()
575
576 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
577 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
578
579 c := mocks.NewSender()
580 s := DecorateSender(c,
581 (func() SendDecorator {
582 return func(s Sender) Sender {
583 return SenderFunc(func(r *http.Request) (*http.Response, error) {
584 if r.Header.Get(http.CanonicalHeaderKey("Content-Type")) != "application/x-www-form-urlencoded" {
585 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set Content-Type -- expected %v, received %v",
586 "application/x-form-urlencoded",
587 r.Header.Get(http.CanonicalHeaderKey("Content-Type")))
588 }
589 return resp, nil
590 })
591 }
592 })())
593 spt.SetSender(s)
594 err := spt.Refresh()
595 if err != nil {
596 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
597 }
598 }
599
600 func TestServicePrincipalTokenRefreshSetsURL(t *testing.T) {
601 spt := newServicePrincipalToken()
602
603 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
604 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
605
606 c := mocks.NewSender()
607 s := DecorateSender(c,
608 (func() SendDecorator {
609 return func(s Sender) Sender {
610 return SenderFunc(func(r *http.Request) (*http.Response, error) {
611 if r.URL.String() != TestOAuthConfig.TokenEndpoint.String() {
612 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the URL -- expected %v, received %v",
613 TestOAuthConfig.TokenEndpoint, r.URL)
614 }
615 return resp, nil
616 })
617 }
618 })())
619 spt.SetSender(s)
620 err := spt.Refresh()
621 if err != nil {
622 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
623 }
624 }
625
626 func testServicePrincipalTokenRefreshSetsBody(t *testing.T, spt *ServicePrincipalToken, f func(*testing.T, []byte)) {
627 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
628 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
629
630 c := mocks.NewSender()
631 s := DecorateSender(c,
632 (func() SendDecorator {
633 return func(s Sender) Sender {
634 return SenderFunc(func(r *http.Request) (*http.Response, error) {
635 b, err := ioutil.ReadAll(r.Body)
636 if err != nil {
637 t.Fatalf("adal: Failed to read body of Service Principal token request (%v)", err)
638 }
639 f(t, b)
640 return resp, nil
641 })
642 }
643 })())
644 spt.SetSender(s)
645 err := spt.Refresh()
646 if err != nil {
647 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
648 }
649 }
650
651 func TestServicePrincipalTokenManualRefreshSetsBody(t *testing.T) {
652 sptManual := newServicePrincipalTokenManual()
653 testServicePrincipalTokenRefreshSetsBody(t, sptManual, func(t *testing.T, b []byte) {
654 if string(b) != defaultManualFormData {
655 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
656 defaultManualFormData, string(b))
657 }
658 })
659 }
660
661 func TestServicePrincipalTokenCertficateRefreshSetsBody(t *testing.T) {
662 sptCert := newServicePrincipalTokenCertificate(t)
663 testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) {
664 body := string(b)
665
666 values, _ := url.ParseQuery(body)
667 if values["client_assertion_type"][0] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ||
668 values["client_id"][0] != "id" ||
669 values["grant_type"][0] != "client_credentials" ||
670 values["resource"][0] != "resource" {
671 t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
672 }
673
674 tok, _ := jwt.Parse(values["client_assertion"][0], nil)
675 if tok == nil {
676 t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to be a JWT")
677 }
678 if _, ok := tok.Header["x5t"]; !ok {
679 t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5t header")
680 }
681 if _, ok := tok.Header["x5c"]; !ok {
682 t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an x5c header")
683 }
684 claims, ok := tok.Claims.(jwt.MapClaims)
685 if !ok {
686 t.Fatalf("expected MapClaims, got %T", tok.Claims)
687 }
688 if err := claims.Valid(); err != nil {
689 t.Fatalf("invalid claim: %v", err)
690 }
691 if aud := claims["aud"]; aud != "https://login.test.com/SomeTenantID/oauth2/token?api-version=1.0" {
692 t.Fatalf("unexpected aud: %s", aud)
693 }
694 if iss := claims["iss"]; iss != "id" {
695 t.Fatalf("unexpected iss: %s", iss)
696 }
697 if sub := claims["sub"]; sub != "id" {
698 t.Fatalf("unexpected sub: %s", sub)
699 }
700 })
701 }
702
703 func TestServicePrincipalTokenUsernamePasswordRefreshSetsBody(t *testing.T) {
704 spt := newServicePrincipalTokenUsernamePassword(t)
705 testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
706 body := string(b)
707
708 values, _ := url.ParseQuery(body)
709 if values["client_id"][0] != "id" ||
710 values["grant_type"][0] != "password" ||
711 values["username"][0] != "username" ||
712 values["password"][0] != "password" ||
713 values["resource"][0] != "resource" {
714 t.Fatalf("adal: ServicePrincipalTokenUsernamePassword#Refresh did not correctly set the HTTP Request Body.")
715 }
716 })
717 }
718
719 func TestServicePrincipalTokenAuthorizationCodeRefreshSetsBody(t *testing.T) {
720 spt := newServicePrincipalTokenAuthorizationCode(t)
721 testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
722 body := string(b)
723
724 values, _ := url.ParseQuery(body)
725 if values["client_id"][0] != "id" ||
726 values["grant_type"][0] != OAuthGrantTypeAuthorizationCode ||
727 values["code"][0] != "code" ||
728 values["client_secret"][0] != "clientSecret" ||
729 values["redirect_uri"][0] != "http://redirectUri/getToken" ||
730 values["resource"][0] != "resource" {
731 t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.")
732 }
733 })
734 testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
735 body := string(b)
736
737 values, _ := url.ParseQuery(body)
738 if values["client_id"][0] != "id" ||
739 values["grant_type"][0] != OAuthGrantTypeRefreshToken ||
740 values["code"][0] != "code" ||
741 values["client_secret"][0] != "clientSecret" ||
742 values["redirect_uri"][0] != "http://redirectUri/getToken" ||
743 values["resource"][0] != "resource" {
744 t.Fatalf("adal: ServicePrincipalTokenAuthorizationCode#Refresh did not correctly set the HTTP Request Body.")
745 }
746 })
747 }
748
749 func TestServicePrincipalTokenSecretRefreshSetsBody(t *testing.T) {
750 spt := newServicePrincipalToken()
751 testServicePrincipalTokenRefreshSetsBody(t, spt, func(t *testing.T, b []byte) {
752 if string(b) != defaultFormData {
753 t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set the HTTP Request Body -- expected %v, received %v",
754 defaultFormData, string(b))
755 }
756
757 })
758 }
759
760 func TestServicePrincipalTokenFederatedJwtRefreshSetsBody(t *testing.T) {
761 sptCert := newServicePrincipalTokenFederatedJwt(t)
762 testServicePrincipalTokenRefreshSetsBody(t, sptCert, func(t *testing.T, b []byte) {
763 body := string(b)
764
765 values, _ := url.ParseQuery(body)
766 if values["client_assertion_type"][0] != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ||
767 values["client_id"][0] != "id" ||
768 values["grant_type"][0] != "client_credentials" ||
769 values["resource"][0] != "resource" {
770 t.Fatalf("adal: ServicePrincipalTokenCertificate#Refresh did not correctly set the HTTP Request Body.")
771 }
772
773 tok, _ := jwt.Parse(values["client_assertion"][0], nil)
774 if tok == nil {
775 t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to be a JWT")
776 }
777 if _, ok := tok.Header["typ"]; !ok {
778 t.Fatalf("adal: ServicePrincipalTokenCertificate#Expected client_assertion to have an typ header")
779 }
780
781 claims, ok := tok.Claims.(jwt.MapClaims)
782 if !ok {
783 t.Fatalf("expected MapClaims, got %T", tok.Claims)
784 }
785 if err := claims.Valid(); err != nil {
786 t.Fatalf("invalid claim: %v", err)
787 }
788 if aud := claims["aud"]; aud != "testAudience" {
789 t.Fatalf("unexpected aud: %s", aud)
790 }
791 if iss := claims["iss"]; iss != "id" {
792 t.Fatalf("unexpected iss: %s", iss)
793 }
794 if sub := claims["sub"]; sub != "id" {
795 t.Fatalf("unexpected sub: %s", sub)
796 }
797 })
798 }
799
800 func TestServicePrincipalTokenRefreshClosesRequestBody(t *testing.T) {
801 spt := newServicePrincipalToken()
802
803 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
804 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
805
806 c := mocks.NewSender()
807 s := DecorateSender(c,
808 (func() SendDecorator {
809 return func(s Sender) Sender {
810 return SenderFunc(func(r *http.Request) (*http.Response, error) {
811 return resp, nil
812 })
813 }
814 })())
815 spt.SetSender(s)
816 err := spt.Refresh()
817 if err != nil {
818 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
819 }
820 if resp.Body.(*mocks.Body).IsOpen() {
821 t.Fatal("adal: ServicePrincipalToken#Refresh failed to close the HTTP Response Body")
822 }
823 }
824
825 func TestServicePrincipalTokenRefreshRejectsResponsesWithStatusNotOK(t *testing.T) {
826 spt := newServicePrincipalToken()
827
828 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
829 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusUnauthorized, "Unauthorized")
830
831 c := mocks.NewSender()
832 s := DecorateSender(c,
833 (func() SendDecorator {
834 return func(s Sender) Sender {
835 return SenderFunc(func(r *http.Request) (*http.Response, error) {
836 return resp, nil
837 })
838 }
839 })())
840 spt.SetSender(s)
841 err := spt.Refresh()
842 if err == nil {
843 t.Fatalf("adal: ServicePrincipalToken#Refresh should reject a response with status != %d", http.StatusOK)
844 }
845 }
846
847 func TestServicePrincipalTokenRefreshRejectsEmptyBody(t *testing.T) {
848 spt := newServicePrincipalToken()
849
850 c := mocks.NewSender()
851 s := DecorateSender(c,
852 (func() SendDecorator {
853 return func(s Sender) Sender {
854 return SenderFunc(func(r *http.Request) (*http.Response, error) {
855 return mocks.NewResponse(), nil
856 })
857 }
858 })())
859 spt.SetSender(s)
860 err := spt.Refresh()
861 if err == nil {
862 t.Fatal("adal: ServicePrincipalToken#Refresh should reject an empty token")
863 }
864 }
865
866 func TestServicePrincipalTokenRefreshPropagatesErrors(t *testing.T) {
867 spt := newServicePrincipalToken()
868
869 c := mocks.NewSender()
870 c.SetError(fmt.Errorf("Faux Error"))
871 spt.SetSender(c)
872
873 err := spt.Refresh()
874 if err == nil {
875 t.Fatal("adal: Failed to propagate the request error")
876 }
877 }
878
879 func TestServicePrincipalTokenRefreshReturnsErrorIfNotOk(t *testing.T) {
880 spt := newServicePrincipalToken()
881
882 c := mocks.NewSender()
883 c.AppendResponse(mocks.NewResponseWithStatus("401 NotAuthorized", http.StatusUnauthorized))
884 spt.SetSender(c)
885
886 err := spt.Refresh()
887 if err == nil {
888 t.Fatalf("adal: Failed to return an when receiving a status code other than HTTP %d", http.StatusOK)
889 }
890 }
891
892 func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
893 spt := newServicePrincipalToken()
894
895 expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
896 j := newTokenJSON(`"3600"`, expiresOn, "resource")
897 resp := mocks.NewResponseWithContent(j)
898 c := mocks.NewSender()
899 s := DecorateSender(c,
900 (func() SendDecorator {
901 return func(s Sender) Sender {
902 return SenderFunc(func(r *http.Request) (*http.Response, error) {
903 return resp, nil
904 })
905 }
906 })())
907 spt.SetSender(s)
908
909 err := spt.Refresh()
910 if err != nil {
911 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
912 } else if spt.inner.Token.AccessToken != "accessToken" ||
913 spt.inner.Token.ExpiresIn != "3600" ||
914 spt.inner.Token.ExpiresOn != json.Number(expiresOn) ||
915 spt.inner.Token.NotBefore != json.Number(expiresOn) ||
916 spt.inner.Token.Resource != "resource" ||
917 spt.inner.Token.Type != "Bearer" {
918 t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v",
919 j, *spt)
920 }
921 }
922
923 func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) {
924 spt := newServicePrincipalToken()
925 expireToken(&spt.inner.Token)
926
927 body := mocks.NewBody(newTokenJSON(`"3600"`, "12345", "test"))
928 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
929
930 f := false
931 c := mocks.NewSender()
932 s := DecorateSender(c,
933 (func() SendDecorator {
934 return func(s Sender) Sender {
935 return SenderFunc(func(r *http.Request) (*http.Response, error) {
936 f = true
937 return resp, nil
938 })
939 }
940 })())
941 spt.SetSender(s)
942 err := spt.EnsureFresh()
943 if err != nil {
944 t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
945 }
946 if !f {
947 t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
948 }
949 }
950
951 func TestServicePrincipalTokenEnsureFreshWithIntExpiresOn(t *testing.T) {
952 spt := newServicePrincipalToken()
953 expireToken(&spt.inner.Token)
954
955 body := mocks.NewBody(newTokenJSONIntExpiresOn(`"3600"`, 12345, "test"))
956 resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")
957
958 f := false
959 c := mocks.NewSender()
960 s := DecorateSender(c,
961 (func() SendDecorator {
962 return func(s Sender) Sender {
963 return SenderFunc(func(r *http.Request) (*http.Response, error) {
964 f = true
965 return resp, nil
966 })
967 }
968 })())
969 spt.SetSender(s)
970 err := spt.EnsureFresh()
971 if err != nil {
972 t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
973 }
974 if !f {
975 t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token")
976 }
977 }
978
979 func TestServicePrincipalTokenEnsureFreshFails1(t *testing.T) {
980 spt := newServicePrincipalToken()
981 expireToken(&spt.inner.Token)
982
983 c := mocks.NewSender()
984 c.SetError(fmt.Errorf("some failure"))
985
986 spt.SetSender(c)
987 err := spt.EnsureFresh()
988 if err == nil {
989 t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error")
990 }
991 if _, ok := err.(TokenRefreshError); ok {
992 t.Fatal("adal: ServicePrincipalToken#EnsureFresh unexpected TokenRefreshError")
993 }
994 }
995
996 func TestServicePrincipalTokenEnsureFreshFails2(t *testing.T) {
997 spt := newServicePrincipalToken()
998 expireToken(&spt.inner.Token)
999
1000 c := mocks.NewSender()
1001 c.AppendResponse(mocks.NewResponseWithStatus("bad request", http.StatusBadRequest))
1002
1003 spt.SetSender(c)
1004 err := spt.EnsureFresh()
1005 if err == nil {
1006 t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return an error")
1007 }
1008 if _, ok := err.(TokenRefreshError); !ok {
1009 t.Fatal("adal: ServicePrincipalToken#EnsureFresh didn't return a TokenRefreshError")
1010 }
1011 }
1012
1013 func TestServicePrincipalTokenEnsureFreshSkipsIfFresh(t *testing.T) {
1014 spt := newServicePrincipalToken()
1015 setTokenToExpireIn(&spt.inner.Token, 1000*time.Second)
1016
1017 f := false
1018 c := mocks.NewSender()
1019 s := DecorateSender(c,
1020 (func() SendDecorator {
1021 return func(s Sender) Sender {
1022 return SenderFunc(func(r *http.Request) (*http.Response, error) {
1023 f = true
1024 return mocks.NewResponse(), nil
1025 })
1026 }
1027 })())
1028 spt.SetSender(s)
1029 err := spt.EnsureFresh()
1030 if err != nil {
1031 t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err)
1032 }
1033 if f {
1034 t.Fatal("adal: ServicePrincipalToken#EnsureFresh invoked Refresh for fresh token")
1035 }
1036 }
1037
1038 func TestRefreshCallback(t *testing.T) {
1039 callbackTriggered := false
1040 spt := newServicePrincipalToken(func(Token) error {
1041 callbackTriggered = true
1042 return nil
1043 })
1044
1045 expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
1046
1047 sender := mocks.NewSender()
1048 j := newTokenJSON(`"3600"`, expiresOn, "resource")
1049 sender.AppendResponse(mocks.NewResponseWithContent(j))
1050 spt.SetSender(sender)
1051 err := spt.Refresh()
1052 if err != nil {
1053 t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
1054 }
1055 if !callbackTriggered {
1056 t.Fatalf("adal: RefreshCallback failed to trigger call callback")
1057 }
1058 }
1059
1060 func TestRefreshCallbackErrorPropagates(t *testing.T) {
1061 errorText := "this is an error text"
1062 spt := newServicePrincipalToken(func(Token) error {
1063 return fmt.Errorf(errorText)
1064 })
1065
1066 expiresOn := strconv.Itoa(int(time.Now().Add(3600 * time.Second).Sub(date.UnixEpoch()).Seconds()))
1067
1068 sender := mocks.NewSender()
1069 j := newTokenJSON(`"3600"`, expiresOn, "resource")
1070 sender.AppendResponse(mocks.NewResponseWithContent(j))
1071 spt.SetSender(sender)
1072 err := spt.Refresh()
1073
1074 if err == nil || !strings.Contains(err.Error(), errorText) {
1075 t.Fatalf("adal: RefreshCallback failed to propagate error")
1076 }
1077 }
1078
1079
1080 func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
1081 spt := newServicePrincipalTokenManual()
1082 spt.inner.Token.RefreshToken = ""
1083 err := spt.Refresh()
1084 if err == nil {
1085 t.Fatalf("adal: ServicePrincipalToken#Refresh should have failed with a ManualTokenSecret without a refresh token")
1086 }
1087 }
1088
1089 func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
1090 const resource = "https://resource"
1091 cb := func(token Token) error { return nil }
1092
1093 spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
1094 if err != nil {
1095 t.Fatalf("Failed to get MSI SPT: %v", err)
1096 }
1097
1098
1099 if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
1100 t.Fatal("SPT secret was not of MSI type")
1101 }
1102
1103 if spt.inner.Resource != resource {
1104 t.Fatal("SPT came back with incorrect resource")
1105 }
1106
1107 if len(spt.refreshCallbacks) != 1 {
1108 t.Fatal("SPT had incorrect refresh callbacks.")
1109 }
1110 }
1111
1112 func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
1113 const (
1114 resource = "https://resource"
1115 userID = "abc123"
1116 )
1117 cb := func(token Token) error { return nil }
1118
1119 spt, err := NewServicePrincipalTokenFromMSIWithUserAssignedID("http://msiendpoint/", resource, userID, cb)
1120 if err != nil {
1121 t.Fatalf("Failed to get MSI SPT: %v", err)
1122 }
1123
1124
1125 if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
1126 t.Fatal("SPT secret was not of MSI type")
1127 }
1128
1129 if spt.inner.Resource != resource {
1130 t.Fatal("SPT came back with incorrect resource")
1131 }
1132
1133 if len(spt.refreshCallbacks) != 1 {
1134 t.Fatal("SPT had incorrect refresh callbacks.")
1135 }
1136
1137 if spt.inner.ClientID != userID {
1138 t.Fatal("SPT had incorrect client ID")
1139 }
1140 }
1141
1142 func TestNewServicePrincipalTokenFromMSIWithIdentityResourceID(t *testing.T) {
1143 const (
1144 resource = "https://resource"
1145 identityResourceID = "/subscriptions/testSub/resourceGroups/testGroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test-identity"
1146 )
1147 cb := func(token Token) error { return nil }
1148
1149 spt, err := NewServicePrincipalTokenFromMSIWithIdentityResourceID("http://msiendpoint/", resource, identityResourceID, cb)
1150 if err != nil {
1151 t.Fatalf("Failed to get MSI SPT: %v", err)
1152 }
1153
1154
1155 if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
1156 t.Fatal("SPT secret was not of MSI type")
1157 }
1158
1159 if spt.inner.Resource != resource {
1160 t.Fatal("SPT came back with incorrect resource")
1161 }
1162
1163 if len(spt.refreshCallbacks) != 1 {
1164 t.Fatal("SPT had incorrect refresh callbacks.")
1165 }
1166
1167 urlPathParameter := url.Values{}
1168 urlPathParameter.Set("mi_res_id", identityResourceID)
1169
1170 if !strings.Contains(spt.inner.OauthConfig.TokenEndpoint.RawQuery, urlPathParameter.Encode()) {
1171 t.Fatal("SPT tokenEndpoint should contains mi_res_id")
1172 }
1173 }
1174
1175 func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) {
1176 token := newToken()
1177 secret := &ServicePrincipalAuthorizationCodeSecret{
1178 ClientSecret: "clientSecret",
1179 AuthorizationCode: "code123",
1180 RedirectURI: "redirect",
1181 }
1182
1183 spt, err := NewServicePrincipalTokenFromManualTokenSecret(TestOAuthConfig, "id", "resource", token, secret, nil)
1184 if err != nil {
1185 t.Fatalf("Failed creating new SPT: %s", err)
1186 }
1187
1188 if !reflect.DeepEqual(token, spt.inner.Token) {
1189 t.Fatalf("Tokens do not match: %s, %s", token, spt.inner.Token)
1190 }
1191
1192 if !reflect.DeepEqual(secret, spt.inner.Secret) {
1193 t.Fatalf("Secrets do not match: %s, %s", secret, spt.inner.Secret)
1194 }
1195
1196 }
1197
1198 func TestGetVMEndpoint(t *testing.T) {
1199 endpoint, err := GetMSIVMEndpoint()
1200 if err != nil {
1201 t.Fatal("Coudn't get VM endpoint")
1202 }
1203
1204 if endpoint != msiEndpoint {
1205 t.Fatal("Didn't get correct endpoint")
1206 }
1207 }
1208
1209 func TestGetAppServiceEndpoint(t *testing.T) {
1210 const testEndpoint = "http://172.16.1.2:8081/msi/token"
1211 const aseSecret = "the_secret"
1212 if err := os.Setenv(msiEndpointEnv, testEndpoint); err != nil {
1213 t.Fatalf("os.Setenv: %v", err)
1214 }
1215 if err := os.Setenv(msiSecretEnv, aseSecret); err != nil {
1216 t.Fatalf("os.Setenv: %v", err)
1217 }
1218 defer func() {
1219 os.Unsetenv(msiEndpointEnv)
1220 os.Unsetenv(msiSecretEnv)
1221 }()
1222
1223 endpoint, err := GetMSIAppServiceEndpoint()
1224 if err != nil {
1225 t.Fatal("Coudn't get App Service endpoint")
1226 }
1227
1228 if endpoint != testEndpoint {
1229 t.Fatal("Didn't get correct endpoint")
1230 }
1231 }
1232
1233 func TestGetMSIEndpoint(t *testing.T) {
1234 const (
1235 testEndpoint = "http://172.16.1.2:8081/msi/token"
1236 testSecret = "DEADBEEF-BBBB-AAAA-DDDD-DDD000000DDD"
1237 )
1238
1239
1240 if err := os.Unsetenv(msiEndpointEnv); err != nil {
1241 t.Fatalf("os.Unsetenv: %v", err)
1242 }
1243
1244 if err := os.Unsetenv(msiSecretEnv); err != nil {
1245 t.Fatalf("os.Unsetenv: %v", err)
1246 }
1247
1248 vmEndpoint, err := GetMSIEndpoint()
1249 if err != nil {
1250 t.Fatal("Coudn't get VM endpoint")
1251 }
1252
1253 if vmEndpoint != msiEndpoint {
1254 t.Fatal("Didn't get correct endpoint")
1255 }
1256
1257
1258 if err := os.Setenv(msiEndpointEnv, testEndpoint); err != nil {
1259 t.Fatalf("os.Setenv: %v", err)
1260 }
1261
1262 if err := os.Setenv(msiSecretEnv, testSecret); err != nil {
1263 t.Fatalf("os.Setenv: %v", err)
1264 }
1265
1266 asEndpoint, err := GetMSIEndpoint()
1267 if err != nil {
1268 t.Fatal("Coudn't get App Service endpoint")
1269 }
1270
1271 if asEndpoint != testEndpoint {
1272 t.Fatal("Didn't get correct endpoint")
1273 }
1274
1275 if err := os.Unsetenv(msiEndpointEnv); err != nil {
1276 t.Fatalf("os.Unsetenv: %v", err)
1277 }
1278
1279 if err := os.Unsetenv(msiSecretEnv); err != nil {
1280 t.Fatalf("os.Unsetenv: %v", err)
1281 }
1282 }
1283
1284 func TestClientSecretWithASESet(t *testing.T) {
1285 if err := os.Setenv(msiEndpointEnv, "http://172.16.1.2:8081/msi/token"); err != nil {
1286 t.Fatalf("os.Setenv: %v", err)
1287 }
1288 if err := os.Setenv(msiSecretEnv, "the_secret"); err != nil {
1289 t.Fatalf("os.Setenv: %v", err)
1290 }
1291 defer func() {
1292 os.Unsetenv(msiEndpointEnv)
1293 os.Unsetenv(msiSecretEnv)
1294 }()
1295 spt := newServicePrincipalToken()
1296 if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
1297 t.Fatal("should not have MSI secret for client secret token even when ASE is enabled")
1298 }
1299 }
1300
1301 func TestMarshalServicePrincipalNoSecret(t *testing.T) {
1302 spt := newServicePrincipalTokenManual()
1303 b, err := json.Marshal(spt)
1304 if err != nil {
1305 t.Fatalf("failed to marshal token: %+v", err)
1306 }
1307 var spt2 *ServicePrincipalToken
1308 err = json.Unmarshal(b, &spt2)
1309 if err != nil {
1310 t.Fatalf("failed to unmarshal token: %+v", err)
1311 }
1312 if !reflect.DeepEqual(spt, spt2) {
1313 t.Fatal("tokens don't match")
1314 }
1315 }
1316
1317 func TestMarshalServicePrincipalTokenSecret(t *testing.T) {
1318 spt := newServicePrincipalToken()
1319 b, err := json.Marshal(spt)
1320 if err != nil {
1321 t.Fatalf("failed to marshal token: %+v", err)
1322 }
1323 var spt2 *ServicePrincipalToken
1324 err = json.Unmarshal(b, &spt2)
1325 if err != nil {
1326 t.Fatalf("failed to unmarshal token: %+v", err)
1327 }
1328 if !reflect.DeepEqual(spt, spt2) {
1329 t.Fatal("tokens don't match")
1330 }
1331 }
1332
1333 func TestMarshalServicePrincipalCertificateSecret(t *testing.T) {
1334 spt := newServicePrincipalTokenCertificate(t)
1335 b, err := json.Marshal(spt)
1336 if err == nil {
1337 t.Fatal("expected error when marshalling certificate token")
1338 }
1339 var spt2 *ServicePrincipalToken
1340 err = json.Unmarshal(b, &spt2)
1341 if err == nil {
1342 t.Fatal("expected error when unmarshalling certificate token")
1343 }
1344 }
1345
1346 func TestMarshalServicePrincipalMSISecret(t *testing.T) {
1347 spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", "", "")
1348 if err != nil {
1349 t.Fatalf("failed to get MSI SPT: %+v", err)
1350 }
1351 b, err := json.Marshal(spt)
1352 if err == nil {
1353 t.Fatal("expected error when marshalling MSI token")
1354 }
1355 var spt2 *ServicePrincipalToken
1356 err = json.Unmarshal(b, &spt2)
1357 if err == nil {
1358 t.Fatal("expected error when unmarshalling MSI token")
1359 }
1360 }
1361
1362 func TestMarshalServicePrincipalUsernamePasswordSecret(t *testing.T) {
1363 spt := newServicePrincipalTokenUsernamePassword(t)
1364 b, err := json.Marshal(spt)
1365 if err != nil {
1366 t.Fatalf("failed to marshal token: %+v", err)
1367 }
1368 var spt2 *ServicePrincipalToken
1369 err = json.Unmarshal(b, &spt2)
1370 if err != nil {
1371 t.Fatalf("failed to unmarshal token: %+v", err)
1372 }
1373 if !reflect.DeepEqual(spt, spt2) {
1374 t.Fatal("tokens don't match")
1375 }
1376 }
1377
1378 func TestMarshalServicePrincipalAuthorizationCodeSecret(t *testing.T) {
1379 spt := newServicePrincipalTokenAuthorizationCode(t)
1380 b, err := json.Marshal(spt)
1381 if err != nil {
1382 t.Fatalf("failed to marshal token: %+v", err)
1383 }
1384 var spt2 *ServicePrincipalToken
1385 err = json.Unmarshal(b, &spt2)
1386 if err != nil {
1387 t.Fatalf("failed to unmarshal token: %+v", err)
1388 }
1389 if !reflect.DeepEqual(spt, spt2) {
1390 t.Fatal("tokens don't match")
1391 }
1392 }
1393
1394 func TestMarshalServicePrincipalFederatedSecret(t *testing.T) {
1395 spt := newServicePrincipalTokenFederatedJwt(t)
1396 b, err := json.Marshal(spt)
1397 if err == nil {
1398 t.Fatal("expected error when marshalling certificate token")
1399 }
1400 var spt2 *ServicePrincipalToken
1401 err = json.Unmarshal(b, &spt2)
1402 if err == nil {
1403 t.Fatal("expected error when unmarshalling certificate token")
1404 }
1405 }
1406
1407 func TestMarshalInnerToken(t *testing.T) {
1408 spt := newServicePrincipalTokenManual()
1409 tokenJSON, err := spt.MarshalTokenJSON()
1410 if err != nil {
1411 t.Fatalf("failed to marshal token: %+v", err)
1412 }
1413
1414 testToken := newToken()
1415 testToken.RefreshToken = "refreshtoken"
1416
1417 testTokenJSON, err := json.Marshal(testToken)
1418 if err != nil {
1419 t.Fatalf("failed to marshal test token: %+v", err)
1420 }
1421
1422 if !reflect.DeepEqual(tokenJSON, testTokenJSON) {
1423 t.Fatalf("tokens don't match: %s, %s", tokenJSON, testTokenJSON)
1424 }
1425
1426 var t1 Token
1427 err = json.Unmarshal(tokenJSON, &t1)
1428 if err != nil {
1429 t.Fatalf("failed to unmarshal token: %+v", err)
1430 }
1431
1432 if !reflect.DeepEqual(t1, testToken) {
1433 t.Fatalf("tokens don't match: %s, %s", t1, testToken)
1434 }
1435 }
1436
1437 func TestNewMultiTenantServicePrincipalToken(t *testing.T) {
1438 cfg, err := NewMultiTenantOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID, TestAuxTenantIDs, OAuthOptions{})
1439 if err != nil {
1440 t.Fatalf("autorest/adal: unexpected error while creating multitenant config: %v", err)
1441 }
1442 mt, err := NewMultiTenantServicePrincipalToken(cfg, "clientID", "superSecret", "resource")
1443 if err != nil {
1444 t.Fatalf("autorest/adal: unexpected error while creating multitenant service principal token: %v", err)
1445 }
1446 if !strings.Contains(mt.PrimaryToken.inner.OauthConfig.AuthorizeEndpoint.String(), TestTenantID) {
1447 t.Fatal("didn't find primary tenant ID in primary SPT")
1448 }
1449 for i := range mt.AuxiliaryTokens {
1450 if ep := mt.AuxiliaryTokens[i].inner.OauthConfig.AuthorizeEndpoint.String(); !strings.Contains(ep, fmt.Sprintf("%s%d", TestAuxTenantPrefix, i)) {
1451 t.Fatalf("didn't find auxiliary tenant ID in token %s", ep)
1452 }
1453 }
1454 }
1455
1456 func TestNewMultiTenantServicePrincipalTokenFromCertificate(t *testing.T) {
1457 cfg, err := NewMultiTenantOAuthConfig(TestActiveDirectoryEndpoint, TestTenantID, TestAuxTenantIDs, OAuthOptions{})
1458 if err != nil {
1459 t.Fatalf("autorest/adal: unexpected error while creating multitenant config: %v", err)
1460 }
1461 cert, key := newTestCertificate(t)
1462 mt, err := NewMultiTenantServicePrincipalTokenFromCertificate(cfg, "clientID", cert, key, "resource")
1463 if err != nil {
1464 t.Fatalf("autorest/adal: unexpected error while creating multitenant service principal token: %v", err)
1465 }
1466 if !strings.Contains(mt.PrimaryToken.inner.OauthConfig.AuthorizeEndpoint.String(), TestTenantID) {
1467 t.Fatal("didn't find primary tenant ID in primary SPT")
1468 }
1469 for i := range mt.AuxiliaryTokens {
1470 if ep := mt.AuxiliaryTokens[i].inner.OauthConfig.AuthorizeEndpoint.String(); !strings.Contains(ep, fmt.Sprintf("%s%d", TestAuxTenantPrefix, i)) {
1471 t.Fatalf("didn't find auxiliary tenant ID in token %s", ep)
1472 }
1473 }
1474 }
1475
1476 func TestMSIAvailableSuccess(t *testing.T) {
1477 c := mocks.NewSender()
1478 c.AppendResponse(mocks.NewResponse())
1479 if !MSIAvailable(context.Background(), c) {
1480 t.Fatal("unexpected false")
1481 }
1482 }
1483
1484 func TestMSIAvailableAppService(t *testing.T) {
1485 os.Setenv("MSI_ENDPOINT", "http://localhost")
1486 os.Setenv("MSI_SECRET", "super")
1487 defer func() {
1488 os.Unsetenv("MSI_ENDPOINT")
1489 os.Unsetenv("MSI_SECRET")
1490 }()
1491 c := mocks.NewSender()
1492 c.AppendResponse(mocks.NewResponse())
1493 available := MSIAvailable(context.Background(), c)
1494
1495 if !available {
1496 t.Fatal("expected MSI to be available")
1497 }
1498 }
1499
1500 func TestMSIAvailableIMDS(t *testing.T) {
1501 c := mocks.NewSender()
1502 c.AppendResponse(mocks.NewResponse())
1503 available := MSIAvailable(context.Background(), c)
1504
1505 if !available {
1506 t.Fatal("expected MSI to be available")
1507 }
1508 }
1509
1510 func TestMSIAvailableSlow(t *testing.T) {
1511 c := mocks.NewSender()
1512
1513 c.AppendResponseWithDelay(mocks.NewResponse(), 5*time.Second)
1514 if MSIAvailable(context.Background(), c) {
1515 t.Fatal("unexpected true")
1516 }
1517 }
1518
1519 func TestMSIAvailableFail(t *testing.T) {
1520 expectErr := "failed to make msi http request"
1521 c := mocks.NewSender()
1522 c.AppendAndRepeatError(fmt.Errorf(expectErr), 2)
1523 if MSIAvailable(context.Background(), c) {
1524 t.Fatal("unexpected true")
1525 }
1526 _, err := getMSIEndpoint(context.Background(), c)
1527 if !strings.Contains(err.Error(), "") {
1528 t.Fatalf("expected error: '%s', but got error '%s'", expectErr, err)
1529 }
1530 }
1531
1532 func newTokenJSON(expiresIn, expiresOn, resource string) string {
1533 nb, err := parseExpiresOn(expiresOn)
1534 if err != nil {
1535 panic(err)
1536 }
1537 return fmt.Sprintf(`{
1538 "access_token" : "accessToken",
1539 "expires_in" : %s,
1540 "expires_on" : "%s",
1541 "not_before" : "%s",
1542 "resource" : "%s",
1543 "token_type" : "Bearer",
1544 "refresh_token": "ABC123"
1545 }`,
1546 expiresIn, expiresOn, nb, resource)
1547 }
1548
// newTokenJSONIntExpiresOn renders a token-endpoint JSON response in which
// expires_on is a bare JSON number. not_before carries the same value as a
// JSON string. expiresIn is inserted verbatim, so callers pass it already
// quoted.
func newTokenJSONIntExpiresOn(expiresIn string, expiresOn int, resource string) string {
	const body = `{
"access_token" : "accessToken",
"expires_in" : %s,
"expires_on" : %d,
"not_before" : "%d",
"resource" : "%s",
"token_type" : "Bearer",
"refresh_token": "ABC123"
}`
	notBefore := expiresOn
	return fmt.Sprintf(body, expiresIn, expiresOn, notBefore, resource)
}
1561
// newADFSTokenJSON renders a minimal ADFS-style token response containing only
// access_token, a numeric expires_in and token_type. It has no expires_on or
// resource fields.
func newADFSTokenJSON(expiresIn int) string {
	const body = `{
"access_token" : "accessToken",
"expires_in" : %d,
"token_type" : "Bearer"
}`
	return fmt.Sprintf(body, expiresIn)
}
1570
1571 func newTokenExpiresIn(expireIn time.Duration) *Token {
1572 t := newToken()
1573 return setTokenToExpireIn(&t, expireIn)
1574 }
1575
1576 func newTokenExpiresAt(expireAt time.Time) *Token {
1577 t := newToken()
1578 return setTokenToExpireAt(&t, expireAt)
1579 }
1580
1581 func expireToken(t *Token) *Token {
1582 return setTokenToExpireIn(t, 0)
1583 }
1584
1585 func setTokenToExpireAt(t *Token, expireAt time.Time) *Token {
1586 t.ExpiresIn = "3600"
1587 t.ExpiresOn = json.Number(strconv.FormatInt(int64(expireAt.Sub(date.UnixEpoch())/time.Second), 10))
1588 t.NotBefore = t.ExpiresOn
1589 return t
1590 }
1591
1592 func setTokenToExpireIn(t *Token, expireIn time.Duration) *Token {
1593 return setTokenToExpireAt(t, time.Now().Add(expireIn))
1594 }
1595
1596 func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincipalToken {
1597 spt, _ := NewServicePrincipalToken(TestOAuthConfig, "id", "secret", "resource", callbacks...)
1598 return spt
1599 }
1600
1601 func newServicePrincipalTokenManual() *ServicePrincipalToken {
1602 token := newToken()
1603 token.RefreshToken = "refreshtoken"
1604 spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", token)
1605 return spt
1606 }
1607
// newTestCertificate generates a 2048-bit RSA key and a self-signed
// certificate for it (CommonName "test", serial number 0, basic constraints
// marked valid). It returns the parsed certificate and the private key. Any
// failure aborts the calling test via t.Fatal.
func newTestCertificate(t *testing.T) (*x509.Certificate, *rsa.PrivateKey) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}

	// Template doubles as issuer, making the certificate self-signed.
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(0),
		Subject:               pkix.Name{CommonName: "test"},
		BasicConstraintsValid: true,
	}
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatal(err)
	}

	cert, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatal(err)
	}
	return cert, key
}
1628
1629 func newServicePrincipalTokenCertificate(t *testing.T) *ServicePrincipalToken {
1630 certificate, privateKey := newTestCertificate(t)
1631
1632 spt, _ := NewServicePrincipalTokenFromCertificate(TestOAuthConfig, "id", certificate, privateKey, "resource")
1633 return spt
1634 }
1635
1636 func newServicePrincipalTokenUsernamePassword(t *testing.T) *ServicePrincipalToken {
1637 spt, _ := NewServicePrincipalTokenFromUsernamePassword(TestOAuthConfig, "id", "username", "password", "resource")
1638 return spt
1639 }
1640
1641 func newServicePrincipalTokenAuthorizationCode(t *testing.T) *ServicePrincipalToken {
1642 spt, _ := NewServicePrincipalTokenFromAuthorizationCode(TestOAuthConfig, "id", "clientSecret", "code", "http://redirectUri/getToken", "resource")
1643 return spt
1644 }
1645
1646 func newServicePrincipalTokenFederatedJwt(t *testing.T) *ServicePrincipalToken {
1647 token := jwt.New(jwt.SigningMethodHS256)
1648 token.Header["typ"] = "JWT"
1649 token.Claims = jwt.MapClaims{
1650 "aud": "testAudience",
1651 "iss": "id",
1652 "sub": "id",
1653 "nbf": time.Now().Unix(),
1654 "exp": time.Now().Add(24 * time.Hour).Unix(),
1655 }
1656
1657 signedString, err := token.SignedString([]byte("test key"))
1658 if err != nil {
1659 t.Fatal(err)
1660 }
1661 spt, _ := NewServicePrincipalTokenFromFederatedToken(TestOAuthConfig, "id", signedString, "resource")
1662 return spt
1663 }
1664
1665 func newServicePrincipalTokenFederatedJwtCallback(t *testing.T, callback JWTCallback, fakeEndpoint string) *ServicePrincipalToken {
1666 outhConfig, _ := NewOAuthConfig(fakeEndpoint, TestTenantID)
1667 spt, _ := NewServicePrincipalTokenFromFederatedTokenCallback(*outhConfig, "id", callback, "resource")
1668 return spt
1669 }
1670
View as plain text