1 package adal
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 import (
18 "context"
19 "crypto/rand"
20 "crypto/rsa"
21 "crypto/sha1"
22 "crypto/x509"
23 "encoding/base64"
24 "encoding/json"
25 "errors"
26 "fmt"
27 "io"
28 "io/ioutil"
29 "math"
30 "net/http"
31 "net/url"
32 "os"
33 "strconv"
34 "strings"
35 "sync"
36 "time"
37
38 "github.com/Azure/go-autorest/autorest/date"
39 "github.com/Azure/go-autorest/logger"
40 "github.com/golang-jwt/jwt/v4"
41 )
42
const (
	// defaultRefresh is the default RefreshWithin window assigned by the
	// constructors: EnsureFresh refreshes once the token expires within it.
	defaultRefresh = 5 * time.Minute
)

// OAuth2 "grant_type" values.
const (
	// OAuthGrantTypeDeviceCode is the "grant_type" for the device code flow.
	OAuthGrantTypeDeviceCode = "device_code"

	// OAuthGrantTypeClientCredentials is the "grant_type" for client credentials flow.
	OAuthGrantTypeClientCredentials = "client_credentials"

	// OAuthGrantTypeUserPass is the "grant_type" for username and password flow.
	OAuthGrantTypeUserPass = "password"

	// OAuthGrantTypeRefreshToken is the "grant_type" for refresh token flow.
	OAuthGrantTypeRefreshToken = "refresh_token"

	// OAuthGrantTypeAuthorizationCode is the "grant_type" for authorization code flow.
	OAuthGrantTypeAuthorizationCode = "authorization_code"
)

// Managed identity (MSI) settings.
const (
	// metadataHeader is the name of the "Metadata" request header.
	metadataHeader = "Metadata"

	// msiEndpoint is the IMDS token endpoint returned by getMSIType when no
	// MSI_ENDPOINT environment variable is set.
	msiEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// msiAPIVersion is the "api-version" query value sent to IMDS.
	msiAPIVersion = "2018-02-01"

	// defaultMaxMSIRefreshAttempts is the default MaxMSIRefreshAttempts for MSI tokens.
	defaultMaxMSIRefreshAttempts = 5

	// msiEndpointEnv names the environment variable that overrides the MSI endpoint.
	msiEndpointEnv = "MSI_ENDPOINT"

	// msiSecretEnv names the environment variable holding the App Service MSI secret.
	msiSecretEnv = "MSI_SECRET"

	// appServiceAPIVersion2017 is the "api-version" query value sent to App Service.
	appServiceAPIVersion2017 = "2017-09-01"

	// secretHeader is the name of the "Secret" request header.
	secretHeader = "Secret"
)

// time.Parse layouts accepted for string-valued expires_on fields.
const (
	// expiresOnDateFormatPM matches date strings carrying an AM/PM marker.
	expiresOnDateFormatPM = "1/2/2006 15:04:05 PM +00:00"

	// expiresOnDateFormat matches 24-hour date strings without an AM/PM marker.
	expiresOnDateFormat = "1/2/2006 15:04:05 +00:00"
)
91
92
93 type OAuthTokenProvider interface {
94 OAuthToken() string
95 }
96
97
98 type MultitenantOAuthTokenProvider interface {
99 PrimaryOAuthToken() string
100 AuxiliaryOAuthTokens() []string
101 }
102
103
104 type TokenRefreshError interface {
105 error
106 Response() *http.Response
107 }
108
109
110 type Refresher interface {
111 Refresh() error
112 RefreshExchange(resource string) error
113 EnsureFresh() error
114 }
115
116
117 type RefresherWithContext interface {
118 RefreshWithContext(ctx context.Context) error
119 RefreshExchangeWithContext(ctx context.Context, resource string) error
120 EnsureFreshWithContext(ctx context.Context) error
121 }
122
123
124
125 type TokenRefreshCallback func(Token) error
126
127
128 type TokenRefresh func(ctx context.Context, resource string) (*Token, error)
129
130
131 type JWTCallback func() (string, error)
132
133
134
// Token holds an OAuth token and its metadata as decoded from a token endpoint
// response (the json tags mirror that response's field names).
type Token struct {
	// AccessToken is the bearer token value returned by OAuthToken.
	AccessToken string `json:"access_token"`
	// RefreshToken, when non-empty, makes refreshes use the refresh_token grant.
	RefreshToken string `json:"refresh_token"`

	// The timing fields are kept as json.Number so their numeric text is
	// preserved. Expires interprets ExpiresOn as seconds since the Unix epoch.
	ExpiresIn json.Number `json:"expires_in"`
	ExpiresOn json.Number `json:"expires_on"`
	NotBefore json.Number `json:"not_before"`

	// Resource is the resource the token was issued for.
	Resource string `json:"resource"`
	// Type is the token_type reported by the endpoint.
	Type string `json:"token_type"`
}
146
147 func newToken() Token {
148 return Token{
149 ExpiresIn: "0",
150 ExpiresOn: "0",
151 NotBefore: "0",
152 }
153 }
154
155
156 func (t Token) IsZero() bool {
157 return t == Token{}
158 }
159
160
161 func (t Token) Expires() time.Time {
162 s, err := t.ExpiresOn.Float64()
163 if err != nil {
164 s = -3600
165 }
166
167 expiration := date.NewUnixTimeFromSeconds(s)
168
169 return time.Time(expiration).UTC()
170 }
171
172
173 func (t Token) IsExpired() bool {
174 return t.WillExpireIn(0)
175 }
176
177
178
179 func (t Token) WillExpireIn(d time.Duration) bool {
180 return !t.Expires().After(time.Now().Add(d))
181 }
182
183
184 func (t *Token) OAuthToken() string {
185 return t.AccessToken
186 }
187
188
189
190 type ServicePrincipalSecret interface {
191 SetAuthenticationValues(spt *ServicePrincipalToken, values *url.Values) error
192 }
193
194
195
196 type ServicePrincipalNoSecret struct {
197 }
198
199
200
201 func (noSecret *ServicePrincipalNoSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
202 return fmt.Errorf("Manually created ServicePrincipalToken does not contain secret material to retrieve a new access token")
203 }
204
205
206 func (noSecret ServicePrincipalNoSecret) MarshalJSON() ([]byte, error) {
207 type tokenType struct {
208 Type string `json:"type"`
209 }
210 return json.Marshal(tokenType{
211 Type: "ServicePrincipalNoSecret",
212 })
213 }
214
215
216 type ServicePrincipalTokenSecret struct {
217 ClientSecret string `json:"value"`
218 }
219
220
221
222 func (tokenSecret *ServicePrincipalTokenSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
223 v.Set("client_secret", tokenSecret.ClientSecret)
224 return nil
225 }
226
227
228 func (tokenSecret ServicePrincipalTokenSecret) MarshalJSON() ([]byte, error) {
229 type tokenType struct {
230 Type string `json:"type"`
231 Value string `json:"value"`
232 }
233 return json.Marshal(tokenType{
234 Type: "ServicePrincipalTokenSecret",
235 Value: tokenSecret.ClientSecret,
236 })
237 }
238
239
240 type ServicePrincipalCertificateSecret struct {
241 Certificate *x509.Certificate
242 PrivateKey *rsa.PrivateKey
243 }
244
245
246 func (secret *ServicePrincipalCertificateSecret) SignJwt(spt *ServicePrincipalToken) (string, error) {
247 hasher := sha1.New()
248 _, err := hasher.Write(secret.Certificate.Raw)
249 if err != nil {
250 return "", err
251 }
252
253 thumbprint := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
254
255
256 jti := make([]byte, 20)
257 _, err = rand.Read(jti)
258 if err != nil {
259 return "", err
260 }
261
262 token := jwt.New(jwt.SigningMethodRS256)
263 token.Header["x5t"] = thumbprint
264 x5c := []string{base64.StdEncoding.EncodeToString(secret.Certificate.Raw)}
265 token.Header["x5c"] = x5c
266 token.Claims = jwt.MapClaims{
267 "aud": spt.inner.OauthConfig.TokenEndpoint.String(),
268 "iss": spt.inner.ClientID,
269 "sub": spt.inner.ClientID,
270 "jti": base64.URLEncoding.EncodeToString(jti),
271 "nbf": time.Now().Unix(),
272 "exp": time.Now().Add(24 * time.Hour).Unix(),
273 }
274
275 signedString, err := token.SignedString(secret.PrivateKey)
276 return signedString, err
277 }
278
279
280
281 func (secret *ServicePrincipalCertificateSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
282 jwt, err := secret.SignJwt(spt)
283 if err != nil {
284 return err
285 }
286
287 v.Set("client_assertion", jwt)
288 v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
289 return nil
290 }
291
292
293 func (secret ServicePrincipalCertificateSecret) MarshalJSON() ([]byte, error) {
294 return nil, errors.New("marshalling ServicePrincipalCertificateSecret is not supported")
295 }
296
297
298 type ServicePrincipalMSISecret struct {
299 msiType msiType
300 clientResourceID string
301 }
302
303
304 func (msiSecret *ServicePrincipalMSISecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
305 return nil
306 }
307
308
309 func (msiSecret ServicePrincipalMSISecret) MarshalJSON() ([]byte, error) {
310 return nil, errors.New("marshalling ServicePrincipalMSISecret is not supported")
311 }
312
313
314 type ServicePrincipalUsernamePasswordSecret struct {
315 Username string `json:"username"`
316 Password string `json:"password"`
317 }
318
319
320 func (secret *ServicePrincipalUsernamePasswordSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
321 v.Set("username", secret.Username)
322 v.Set("password", secret.Password)
323 return nil
324 }
325
326
327 func (secret ServicePrincipalUsernamePasswordSecret) MarshalJSON() ([]byte, error) {
328 type tokenType struct {
329 Type string `json:"type"`
330 Username string `json:"username"`
331 Password string `json:"password"`
332 }
333 return json.Marshal(tokenType{
334 Type: "ServicePrincipalUsernamePasswordSecret",
335 Username: secret.Username,
336 Password: secret.Password,
337 })
338 }
339
340
341 type ServicePrincipalAuthorizationCodeSecret struct {
342 ClientSecret string `json:"value"`
343 AuthorizationCode string `json:"authCode"`
344 RedirectURI string `json:"redirect"`
345 }
346
347
348 func (secret *ServicePrincipalAuthorizationCodeSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error {
349 v.Set("code", secret.AuthorizationCode)
350 v.Set("client_secret", secret.ClientSecret)
351 v.Set("redirect_uri", secret.RedirectURI)
352 return nil
353 }
354
355
356 func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, error) {
357 type tokenType struct {
358 Type string `json:"type"`
359 Value string `json:"value"`
360 AuthCode string `json:"authCode"`
361 Redirect string `json:"redirect"`
362 }
363 return json.Marshal(tokenType{
364 Type: "ServicePrincipalAuthorizationCodeSecret",
365 Value: secret.ClientSecret,
366 AuthCode: secret.AuthorizationCode,
367 Redirect: secret.RedirectURI,
368 })
369 }
370
371
372 type ServicePrincipalFederatedSecret struct {
373 jwtCallback JWTCallback
374 }
375
376
377
378 func (secret *ServicePrincipalFederatedSecret) SetAuthenticationValues(_ *ServicePrincipalToken, v *url.Values) error {
379 jwt, err := secret.jwtCallback()
380 if err != nil {
381 return err
382 }
383
384 v.Set("client_assertion", jwt)
385 v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
386 return nil
387 }
388
389
390 func (secret ServicePrincipalFederatedSecret) MarshalJSON() ([]byte, error) {
391 return nil, errors.New("marshalling ServicePrincipalFederatedSecret is not supported")
392 }
393
394
// ServicePrincipalToken holds a Token together with the credential and
// settings needed to refresh it.
type ServicePrincipalToken struct {
	// inner is the serializable state; MarshalJSON/UnmarshalJSON operate on it.
	inner servicePrincipalToken
	// refreshLock serializes refreshes; the refresh methods hold its write lock
	// while calling refreshInternal.
	refreshLock *sync.RWMutex
	// sender executes the refresh HTTP requests.
	sender Sender
	// customRefreshFunc, when non-nil, replaces the built-in HTTP refresh.
	customRefreshFunc TokenRefresh
	// refreshCallbacks are invoked after each successful refresh.
	refreshCallbacks []TokenRefreshCallback

	// MaxMSIRefreshAttempts is passed to retryForIMDS as the maximum number of
	// attempts when refreshing a managed-identity token.
	MaxMSIRefreshAttempts int
}
405
406
407 func (spt ServicePrincipalToken) MarshalTokenJSON() ([]byte, error) {
408 return json.Marshal(spt.inner.Token)
409 }
410
411
412 func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCallback) {
413 spt.refreshCallbacks = callbacks
414 }
415
416
417 func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
418 spt.customRefreshFunc = customRefreshFunc
419 }
420
421
422 func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
423 return json.Marshal(spt.inner)
424 }
425
426
427 func (spt *ServicePrincipalToken) UnmarshalJSON(data []byte) error {
428
429 raw := map[string]interface{}{}
430 err := json.Unmarshal(data, &raw)
431 if err != nil {
432 return err
433 }
434 secret := raw["secret"].(map[string]interface{})
435 switch secret["type"] {
436 case "ServicePrincipalNoSecret":
437 spt.inner.Secret = &ServicePrincipalNoSecret{}
438 case "ServicePrincipalTokenSecret":
439 spt.inner.Secret = &ServicePrincipalTokenSecret{}
440 case "ServicePrincipalCertificateSecret":
441 return errors.New("unmarshalling ServicePrincipalCertificateSecret is not supported")
442 case "ServicePrincipalMSISecret":
443 return errors.New("unmarshalling ServicePrincipalMSISecret is not supported")
444 case "ServicePrincipalUsernamePasswordSecret":
445 spt.inner.Secret = &ServicePrincipalUsernamePasswordSecret{}
446 case "ServicePrincipalAuthorizationCodeSecret":
447 spt.inner.Secret = &ServicePrincipalAuthorizationCodeSecret{}
448 case "ServicePrincipalFederatedSecret":
449 return errors.New("unmarshalling ServicePrincipalFederatedSecret is not supported")
450 default:
451 return fmt.Errorf("unrecognized token type '%s'", secret["type"])
452 }
453 err = json.Unmarshal(data, &spt.inner)
454 if err != nil {
455 return err
456 }
457
458 if spt.refreshLock == nil {
459 spt.refreshLock = &sync.RWMutex{}
460 }
461 if spt.sender == nil {
462 spt.sender = sender()
463 }
464 return nil
465 }
466
467
// servicePrincipalToken is the serializable state of a ServicePrincipalToken.
type servicePrincipalToken struct {
	// Token is the current token.
	Token Token `json:"token"`
	// Secret supplies the credentials added to refresh requests.
	Secret ServicePrincipalSecret `json:"secret"`
	// OauthConfig provides the token endpoint that refresh requests target.
	OauthConfig OAuthConfig `json:"oauth"`
	// ClientID is sent as "client_id" in refresh requests.
	ClientID string `json:"clientID"`
	// Resource is the default resource requested by Refresh/EnsureFresh.
	Resource string `json:"resource"`
	// AutoRefresh enables EnsureFresh to refresh a token that will expire
	// within RefreshWithin.
	AutoRefresh   bool          `json:"autoRefresh"`
	RefreshWithin time.Duration `json:"refreshWithin"`
}
477
478 func validateOAuthConfig(oac OAuthConfig) error {
479 if oac.IsZero() {
480 return fmt.Errorf("parameter 'oauthConfig' cannot be zero-initialized")
481 }
482 return nil
483 }
484
485
486 func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, resource string, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
487 if err := validateOAuthConfig(oauthConfig); err != nil {
488 return nil, err
489 }
490 if err := validateStringParam(id, "id"); err != nil {
491 return nil, err
492 }
493 if err := validateStringParam(resource, "resource"); err != nil {
494 return nil, err
495 }
496 if secret == nil {
497 return nil, fmt.Errorf("parameter 'secret' cannot be nil")
498 }
499 spt := &ServicePrincipalToken{
500 inner: servicePrincipalToken{
501 Token: newToken(),
502 OauthConfig: oauthConfig,
503 Secret: secret,
504 ClientID: id,
505 Resource: resource,
506 AutoRefresh: true,
507 RefreshWithin: defaultRefresh,
508 },
509 refreshLock: &sync.RWMutex{},
510 sender: sender(),
511 refreshCallbacks: callbacks,
512 }
513 return spt, nil
514 }
515
516
517 func NewServicePrincipalTokenFromManualToken(oauthConfig OAuthConfig, clientID string, resource string, token Token, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
518 if err := validateOAuthConfig(oauthConfig); err != nil {
519 return nil, err
520 }
521 if err := validateStringParam(clientID, "clientID"); err != nil {
522 return nil, err
523 }
524 if err := validateStringParam(resource, "resource"); err != nil {
525 return nil, err
526 }
527 if token.IsZero() {
528 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
529 }
530 spt, err := NewServicePrincipalTokenWithSecret(
531 oauthConfig,
532 clientID,
533 resource,
534 &ServicePrincipalNoSecret{},
535 callbacks...)
536 if err != nil {
537 return nil, err
538 }
539
540 spt.inner.Token = token
541
542 return spt, nil
543 }
544
545
546 func NewServicePrincipalTokenFromManualTokenSecret(oauthConfig OAuthConfig, clientID string, resource string, token Token, secret ServicePrincipalSecret, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
547 if err := validateOAuthConfig(oauthConfig); err != nil {
548 return nil, err
549 }
550 if err := validateStringParam(clientID, "clientID"); err != nil {
551 return nil, err
552 }
553 if err := validateStringParam(resource, "resource"); err != nil {
554 return nil, err
555 }
556 if secret == nil {
557 return nil, fmt.Errorf("parameter 'secret' cannot be nil")
558 }
559 if token.IsZero() {
560 return nil, fmt.Errorf("parameter 'token' cannot be zero-initialized")
561 }
562 spt, err := NewServicePrincipalTokenWithSecret(
563 oauthConfig,
564 clientID,
565 resource,
566 secret,
567 callbacks...)
568 if err != nil {
569 return nil, err
570 }
571
572 spt.inner.Token = token
573
574 return spt, nil
575 }
576
577
578
579 func NewServicePrincipalToken(oauthConfig OAuthConfig, clientID string, secret string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
580 if err := validateOAuthConfig(oauthConfig); err != nil {
581 return nil, err
582 }
583 if err := validateStringParam(clientID, "clientID"); err != nil {
584 return nil, err
585 }
586 if err := validateStringParam(secret, "secret"); err != nil {
587 return nil, err
588 }
589 if err := validateStringParam(resource, "resource"); err != nil {
590 return nil, err
591 }
592 return NewServicePrincipalTokenWithSecret(
593 oauthConfig,
594 clientID,
595 resource,
596 &ServicePrincipalTokenSecret{
597 ClientSecret: secret,
598 },
599 callbacks...,
600 )
601 }
602
603
604 func NewServicePrincipalTokenFromCertificate(oauthConfig OAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
605 if err := validateOAuthConfig(oauthConfig); err != nil {
606 return nil, err
607 }
608 if err := validateStringParam(clientID, "clientID"); err != nil {
609 return nil, err
610 }
611 if err := validateStringParam(resource, "resource"); err != nil {
612 return nil, err
613 }
614 if certificate == nil {
615 return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
616 }
617 if privateKey == nil {
618 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
619 }
620 return NewServicePrincipalTokenWithSecret(
621 oauthConfig,
622 clientID,
623 resource,
624 &ServicePrincipalCertificateSecret{
625 PrivateKey: privateKey,
626 Certificate: certificate,
627 },
628 callbacks...,
629 )
630 }
631
632
633 func NewServicePrincipalTokenFromUsernamePassword(oauthConfig OAuthConfig, clientID string, username string, password string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
634 if err := validateOAuthConfig(oauthConfig); err != nil {
635 return nil, err
636 }
637 if err := validateStringParam(clientID, "clientID"); err != nil {
638 return nil, err
639 }
640 if err := validateStringParam(username, "username"); err != nil {
641 return nil, err
642 }
643 if err := validateStringParam(password, "password"); err != nil {
644 return nil, err
645 }
646 if err := validateStringParam(resource, "resource"); err != nil {
647 return nil, err
648 }
649 return NewServicePrincipalTokenWithSecret(
650 oauthConfig,
651 clientID,
652 resource,
653 &ServicePrincipalUsernamePasswordSecret{
654 Username: username,
655 Password: password,
656 },
657 callbacks...,
658 )
659 }
660
661
662 func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clientID string, clientSecret string, authorizationCode string, redirectURI string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
663
664 if err := validateOAuthConfig(oauthConfig); err != nil {
665 return nil, err
666 }
667 if err := validateStringParam(clientID, "clientID"); err != nil {
668 return nil, err
669 }
670 if err := validateStringParam(clientSecret, "clientSecret"); err != nil {
671 return nil, err
672 }
673 if err := validateStringParam(authorizationCode, "authorizationCode"); err != nil {
674 return nil, err
675 }
676 if err := validateStringParam(redirectURI, "redirectURI"); err != nil {
677 return nil, err
678 }
679 if err := validateStringParam(resource, "resource"); err != nil {
680 return nil, err
681 }
682
683 return NewServicePrincipalTokenWithSecret(
684 oauthConfig,
685 clientID,
686 resource,
687 &ServicePrincipalAuthorizationCodeSecret{
688 ClientSecret: clientSecret,
689 AuthorizationCode: authorizationCode,
690 RedirectURI: redirectURI,
691 },
692 callbacks...,
693 )
694 }
695
696
697
698
699 func NewServicePrincipalTokenFromFederatedToken(oauthConfig OAuthConfig, clientID string, jwt string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
700 if err := validateOAuthConfig(oauthConfig); err != nil {
701 return nil, err
702 }
703 if err := validateStringParam(clientID, "clientID"); err != nil {
704 return nil, err
705 }
706 if err := validateStringParam(resource, "resource"); err != nil {
707 return nil, err
708 }
709 if jwt == "" {
710 return nil, fmt.Errorf("parameter 'jwt' cannot be empty")
711 }
712 return NewServicePrincipalTokenFromFederatedTokenCallback(
713 oauthConfig,
714 clientID,
715 func() (string, error) {
716 return jwt, nil
717 },
718 resource,
719 callbacks...,
720 )
721 }
722
723
724 func NewServicePrincipalTokenFromFederatedTokenCallback(oauthConfig OAuthConfig, clientID string, jwtCallback JWTCallback, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
725 if err := validateOAuthConfig(oauthConfig); err != nil {
726 return nil, err
727 }
728 if err := validateStringParam(clientID, "clientID"); err != nil {
729 return nil, err
730 }
731 if err := validateStringParam(resource, "resource"); err != nil {
732 return nil, err
733 }
734 if jwtCallback == nil {
735 return nil, fmt.Errorf("parameter 'jwtCallback' cannot be empty")
736 }
737 return NewServicePrincipalTokenWithSecret(
738 oauthConfig,
739 clientID,
740 resource,
741 &ServicePrincipalFederatedSecret{
742 jwtCallback: jwtCallback,
743 },
744 callbacks...,
745 )
746 }
747
// msiType identifies the managed-identity environment the process runs in.
type msiType int

const (
	msiTypeUnavailable msiType = iota
	msiTypeAppServiceV20170901
	msiTypeCloudShell
	msiTypeIMDS
)

// msiTypeDisplayNames maps each handled msiType to its String form.
var msiTypeDisplayNames = map[msiType]string{
	msiTypeAppServiceV20170901: "AppServiceV20170901",
	msiTypeCloudShell:          "CloudShell",
	msiTypeIMDS:                "IMDS",
}

// String returns a readable name for m, or "unhandled MSI type N" for values
// without one (including msiTypeUnavailable).
func (m msiType) String() string {
	if name, ok := msiTypeDisplayNames[m]; ok {
		return name
	}
	return fmt.Sprintf("unhandled MSI type %d", m)
}
769
770
771 func getMSIType() (msiType, string, error) {
772 if endpointEnvVar := os.Getenv(msiEndpointEnv); endpointEnvVar != "" {
773
774 if secretEnvVar := os.Getenv(msiSecretEnv); secretEnvVar != "" {
775
776 return msiTypeAppServiceV20170901, endpointEnvVar, nil
777 }
778
779 return msiTypeCloudShell, endpointEnvVar, nil
780 }
781
782 return msiTypeIMDS, msiEndpoint, nil
783 }
784
785
786
787
788 func GetMSIVMEndpoint() (string, error) {
789 return msiEndpoint, nil
790 }
791
792
793
794
795 func GetMSIAppServiceEndpoint() (string, error) {
796 msiType, endpoint, err := getMSIType()
797 if err != nil {
798 return "", err
799 }
800 switch msiType {
801 case msiTypeAppServiceV20170901:
802 return endpoint, nil
803 default:
804 return "", fmt.Errorf("%s is not app service environment", msiType)
805 }
806 }
807
808
809
810 func GetMSIEndpoint() (string, error) {
811 _, endpoint, err := getMSIType()
812 return endpoint, err
813 }
814
815
816
817
818
819 func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
820 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", "", callbacks...)
821 }
822
823
824
825
826
827 func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
828 if err := validateStringParam(userAssignedID, "userAssignedID"); err != nil {
829 return nil, err
830 }
831 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, "", callbacks...)
832 }
833
834
835
836
837
838 func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
839 if err := validateStringParam(identityResourceID, "identityResourceID"); err != nil {
840 return nil, err
841 }
842 return newServicePrincipalTokenFromMSI(msiEndpoint, resource, "", identityResourceID, callbacks...)
843 }
844
845
// ManagedIdentityOptions selects which managed identity
// NewServicePrincipalTokenFromManagedIdentity uses. Setting both fields is an
// error ("cannot specify userAssignedID and identityResourceID").
type ManagedIdentityOptions struct {
	// ClientID is the client ID of a user-assigned identity. When non-empty it
	// is sent as the client ID parameter of the MSI token request.
	ClientID string

	// IdentityResourceID is the resource ID of a user-assigned identity. When
	// non-empty (and ClientID is empty) it is sent as "mi_res_id" (or
	// "msi_res_id" for Cloud Shell).
	IdentityResourceID string
}
855
856
857
858
859
860
861 func NewServicePrincipalTokenFromManagedIdentity(resource string, options *ManagedIdentityOptions, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
862 if options == nil {
863 options = &ManagedIdentityOptions{}
864 }
865 return newServicePrincipalTokenFromMSI("", resource, options.ClientID, options.IdentityResourceID, callbacks...)
866 }
867
868 func newServicePrincipalTokenFromMSI(msiEndpoint, resource, userAssignedID, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
869 if err := validateStringParam(resource, "resource"); err != nil {
870 return nil, err
871 }
872 if userAssignedID != "" && identityResourceID != "" {
873 return nil, errors.New("cannot specify userAssignedID and identityResourceID")
874 }
875 msiType, endpoint, err := getMSIType()
876 if err != nil {
877 logger.Instance.Writef(logger.LogError, "Error determining managed identity environment: %v\n", err)
878 return nil, err
879 }
880 logger.Instance.Writef(logger.LogInfo, "Managed identity environment is %s, endpoint is %s\n", msiType, endpoint)
881 if msiEndpoint != "" {
882 endpoint = msiEndpoint
883 logger.Instance.Writef(logger.LogInfo, "Managed identity custom endpoint is %s\n", endpoint)
884 }
885 msiEndpointURL, err := url.Parse(endpoint)
886 if err != nil {
887 return nil, err
888 }
889
890 if msiType != msiTypeCloudShell {
891 v := url.Values{}
892 v.Set("resource", resource)
893 clientIDParam := "client_id"
894 switch msiType {
895 case msiTypeAppServiceV20170901:
896 clientIDParam = "clientid"
897 v.Set("api-version", appServiceAPIVersion2017)
898 break
899 case msiTypeIMDS:
900 v.Set("api-version", msiAPIVersion)
901 }
902 if userAssignedID != "" {
903 v.Set(clientIDParam, userAssignedID)
904 } else if identityResourceID != "" {
905 v.Set("mi_res_id", identityResourceID)
906 }
907 msiEndpointURL.RawQuery = v.Encode()
908 }
909
910 spt := &ServicePrincipalToken{
911 inner: servicePrincipalToken{
912 Token: newToken(),
913 OauthConfig: OAuthConfig{
914 TokenEndpoint: *msiEndpointURL,
915 },
916 Secret: &ServicePrincipalMSISecret{
917 msiType: msiType,
918 clientResourceID: identityResourceID,
919 },
920 Resource: resource,
921 AutoRefresh: true,
922 RefreshWithin: defaultRefresh,
923 ClientID: userAssignedID,
924 },
925 refreshLock: &sync.RWMutex{},
926 sender: sender(),
927 refreshCallbacks: callbacks,
928 MaxMSIRefreshAttempts: defaultMaxMSIRefreshAttempts,
929 }
930
931 return spt, nil
932 }
933
934
935 type tokenRefreshError struct {
936 message string
937 resp *http.Response
938 }
939
940
941 func (tre tokenRefreshError) Error() string {
942 return tre.message
943 }
944
945
946 func (tre tokenRefreshError) Response() *http.Response {
947 return tre.resp
948 }
949
950 func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError {
951 return tokenRefreshError{message: message, resp: resp}
952 }
953
954
955
956 func (spt *ServicePrincipalToken) EnsureFresh() error {
957 return spt.EnsureFreshWithContext(context.Background())
958 }
959
960
961
962 func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
963
964 if spt.inner.AutoRefresh && spt.Token().WillExpireIn(spt.inner.RefreshWithin) {
965
966 spt.refreshLock.Lock()
967 defer spt.refreshLock.Unlock()
968 if spt.inner.Token.WillExpireIn(spt.inner.RefreshWithin) {
969 return spt.refreshInternal(ctx, spt.inner.Resource)
970 }
971 }
972 return nil
973 }
974
975
976 func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
977 if spt.refreshCallbacks != nil {
978 for _, callback := range spt.refreshCallbacks {
979 err := callback(spt.inner.Token)
980 if err != nil {
981 return fmt.Errorf("adal: TokenRefreshCallback handler failed. Error = '%v'", err)
982 }
983 }
984 }
985 return nil
986 }
987
988
989
990 func (spt *ServicePrincipalToken) Refresh() error {
991 return spt.RefreshWithContext(context.Background())
992 }
993
994
995
996 func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
997 spt.refreshLock.Lock()
998 defer spt.refreshLock.Unlock()
999 return spt.refreshInternal(ctx, spt.inner.Resource)
1000 }
1001
1002
1003
1004 func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
1005 return spt.RefreshExchangeWithContext(context.Background(), resource)
1006 }
1007
1008
1009
1010 func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
1011 spt.refreshLock.Lock()
1012 defer spt.refreshLock.Unlock()
1013 return spt.refreshInternal(ctx, resource)
1014 }
1015
1016 func (spt *ServicePrincipalToken) getGrantType() string {
1017 switch spt.inner.Secret.(type) {
1018 case *ServicePrincipalUsernamePasswordSecret:
1019 return OAuthGrantTypeUserPass
1020 case *ServicePrincipalAuthorizationCodeSecret:
1021 return OAuthGrantTypeAuthorizationCode
1022 default:
1023 return OAuthGrantTypeClientCredentials
1024 }
1025 }
1026
1027 func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
1028 if spt.customRefreshFunc != nil {
1029 token, err := spt.customRefreshFunc(ctx, resource)
1030 if err != nil {
1031 return err
1032 }
1033 spt.inner.Token = *token
1034 return spt.InvokeRefreshCallbacks(spt.inner.Token)
1035 }
1036 req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
1037 if err != nil {
1038 return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
1039 }
1040 req.Header.Add("User-Agent", UserAgent())
1041 req = req.WithContext(ctx)
1042 var resp *http.Response
1043 authBodyFilter := func(b []byte) []byte {
1044 if logger.Level() != logger.LogAuth {
1045 return []byte("**REDACTED** authentication body")
1046 }
1047 return b
1048 }
1049 if msiSecret, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); ok {
1050 switch msiSecret.msiType {
1051 case msiTypeAppServiceV20170901:
1052 req.Method = http.MethodGet
1053 req.Header.Set("secret", os.Getenv(msiSecretEnv))
1054 break
1055 case msiTypeCloudShell:
1056 req.Header.Set("Metadata", "true")
1057 data := url.Values{}
1058 data.Set("resource", spt.inner.Resource)
1059 if spt.inner.ClientID != "" {
1060 data.Set("client_id", spt.inner.ClientID)
1061 } else if msiSecret.clientResourceID != "" {
1062 data.Set("msi_res_id", msiSecret.clientResourceID)
1063 }
1064 req.Body = ioutil.NopCloser(strings.NewReader(data.Encode()))
1065 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
1066 break
1067 case msiTypeIMDS:
1068 req.Method = http.MethodGet
1069 req.Header.Set("Metadata", "true")
1070 break
1071 }
1072 logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
1073 resp, err = retryForIMDS(spt.sender, req, spt.MaxMSIRefreshAttempts)
1074 } else {
1075 v := url.Values{}
1076 v.Set("client_id", spt.inner.ClientID)
1077 v.Set("resource", resource)
1078
1079 if spt.inner.Token.RefreshToken != "" {
1080 v.Set("grant_type", OAuthGrantTypeRefreshToken)
1081 v.Set("refresh_token", spt.inner.Token.RefreshToken)
1082
1083
1084 if spt.getGrantType() == OAuthGrantTypeAuthorizationCode {
1085 err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
1086 if err != nil {
1087 return err
1088 }
1089 }
1090 } else {
1091 v.Set("grant_type", spt.getGrantType())
1092 err := spt.inner.Secret.SetAuthenticationValues(spt, &v)
1093 if err != nil {
1094 return err
1095 }
1096 }
1097
1098 s := v.Encode()
1099 body := ioutil.NopCloser(strings.NewReader(s))
1100 req.ContentLength = int64(len(s))
1101 req.Header.Set(contentType, mimeTypeFormPost)
1102 req.Body = body
1103 logger.Instance.WriteRequest(req, logger.Filter{Body: authBodyFilter})
1104 resp, err = spt.sender.Do(req)
1105 }
1106
1107
1108 if err != nil {
1109 return fmt.Errorf("adal: Failed to execute the refresh request. Error = '%v'", err)
1110 } else if resp == nil {
1111 return fmt.Errorf("adal: received nil response and error")
1112 }
1113
1114 logger.Instance.WriteResponse(resp, logger.Filter{Body: authBodyFilter})
1115 defer resp.Body.Close()
1116 rb, err := ioutil.ReadAll(resp.Body)
1117
1118 if resp.StatusCode != http.StatusOK {
1119 if err != nil {
1120 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Failed reading response body: %v Endpoint %s", resp.StatusCode, err, req.URL.String()), resp)
1121 }
1122 return newTokenRefreshError(fmt.Sprintf("adal: Refresh request failed. Status Code = '%d'. Response body: %s Endpoint %s", resp.StatusCode, string(rb), req.URL.String()), resp)
1123 }
1124
1125
1126
1127
1128
1129 if err != nil {
1130 return fmt.Errorf("adal: Failed to read a new service principal token during refresh. Error = '%v'", err)
1131 }
1132 if len(strings.Trim(string(rb), " ")) == 0 {
1133 return fmt.Errorf("adal: Empty service principal token received during refresh")
1134 }
1135 token := struct {
1136 AccessToken string `json:"access_token"`
1137 RefreshToken string `json:"refresh_token"`
1138
1139
1140 ExpiresIn json.Number `json:"expires_in"`
1141
1142 ExpiresOn interface{} `json:"expires_on"`
1143 NotBefore json.Number `json:"not_before"`
1144
1145 Resource string `json:"resource"`
1146 Type string `json:"token_type"`
1147 }{}
1148
1149 err = json.Unmarshal(rb, &token)
1150 if err != nil {
1151 return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)), resp)
1152 }
1153 expiresOn := json.Number("")
1154
1155 if token.ExpiresOn != nil {
1156 if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil {
1157 return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp)
1158 }
1159 }
1160 spt.inner.Token.AccessToken = token.AccessToken
1161 spt.inner.Token.RefreshToken = token.RefreshToken
1162 spt.inner.Token.ExpiresIn = token.ExpiresIn
1163 spt.inner.Token.ExpiresOn = expiresOn
1164 spt.inner.Token.NotBefore = token.NotBefore
1165 spt.inner.Token.Resource = token.Resource
1166 spt.inner.Token.Type = token.Type
1167
1168 return spt.InvokeRefreshCallbacks(spt.inner.Token)
1169 }
1170
1171
1172 func parseExpiresOn(s interface{}) (json.Number, error) {
1173
1174 asFloat64, ok := s.(float64)
1175 if ok {
1176
1177 return json.Number(strconv.FormatInt(int64(asFloat64), 10)), nil
1178 }
1179 asStr, ok := s.(string)
1180 if !ok {
1181 return "", fmt.Errorf("unexpected expires_on type %T", s)
1182 }
1183
1184 timeToDuration := func(t time.Time) json.Number {
1185 return json.Number(strconv.FormatInt(t.UTC().Unix(), 10))
1186 }
1187 if _, err := json.Number(asStr).Int64(); err == nil {
1188
1189 return json.Number(asStr), nil
1190 } else if eo, err := time.Parse(expiresOnDateFormatPM, asStr); err == nil {
1191 return timeToDuration(eo), nil
1192 } else if eo, err := time.Parse(expiresOnDateFormat, asStr); err == nil {
1193 return timeToDuration(eo), nil
1194 } else {
1195
1196 return json.Number(""), err
1197 }
1198 }
1199
1200
1201 func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http.Response, err error) {
1202
1203 retries := []int{
1204 http.StatusRequestTimeout,
1205 http.StatusTooManyRequests,
1206 http.StatusInternalServerError,
1207 http.StatusBadGateway,
1208 http.StatusServiceUnavailable,
1209 http.StatusGatewayTimeout,
1210 }
1211
1212 retries = append(retries,
1213 http.StatusNotFound,
1214 http.StatusGone,
1215
1216 http.StatusNotImplemented,
1217 http.StatusHTTPVersionNotSupported,
1218 http.StatusVariantAlsoNegotiates,
1219 http.StatusInsufficientStorage,
1220 http.StatusLoopDetected,
1221 http.StatusNotExtended,
1222 http.StatusNetworkAuthenticationRequired)
1223
1224
1225
1226 const maxDelay time.Duration = 60 * time.Second
1227
1228 attempt := 0
1229 delay := time.Duration(0)
1230
1231
1232 if maxAttempts < 1 {
1233 maxAttempts = defaultMaxMSIRefreshAttempts
1234 }
1235
1236 for attempt < maxAttempts {
1237 if resp != nil && resp.Body != nil {
1238 io.Copy(ioutil.Discard, resp.Body)
1239 resp.Body.Close()
1240 }
1241 resp, err = sender.Do(req)
1242
1243 if err == nil && !responseHasStatusCode(resp, retries...) {
1244 return
1245 }
1246
1247
1248
1249 attempt++
1250
1251 delay += (time.Duration(math.Pow(2, float64(attempt))) * time.Second)
1252 if delay > maxDelay {
1253 delay = maxDelay
1254 }
1255
1256 select {
1257 case <-time.After(delay):
1258
1259 case <-req.Context().Done():
1260 err = req.Context().Err()
1261 return
1262 }
1263 }
1264 return
1265 }
1266
// responseHasStatusCode reports whether resp is non-nil and its StatusCode
// equals any of codes. A nil resp, or an empty codes list, yields false.
func responseHasStatusCode(resp *http.Response, codes ...int) bool {
	if resp == nil {
		return false
	}
	for _, code := range codes {
		if code == resp.StatusCode {
			return true
		}
	}
	return false
}
1277
1278
1279 func (spt *ServicePrincipalToken) SetAutoRefresh(autoRefresh bool) {
1280 spt.inner.AutoRefresh = autoRefresh
1281 }
1282
1283
1284
1285 func (spt *ServicePrincipalToken) SetRefreshWithin(d time.Duration) {
1286 spt.inner.RefreshWithin = d
1287 return
1288 }
1289
1290
1291
1292 func (spt *ServicePrincipalToken) SetSender(s Sender) { spt.sender = s }
1293
1294
1295 func (spt *ServicePrincipalToken) OAuthToken() string {
1296 spt.refreshLock.RLock()
1297 defer spt.refreshLock.RUnlock()
1298 return spt.inner.Token.OAuthToken()
1299 }
1300
1301
1302 func (spt *ServicePrincipalToken) Token() Token {
1303 spt.refreshLock.RLock()
1304 defer spt.refreshLock.RUnlock()
1305 return spt.inner.Token
1306 }
1307
1308
1309 type MultiTenantServicePrincipalToken struct {
1310 PrimaryToken *ServicePrincipalToken
1311 AuxiliaryTokens []*ServicePrincipalToken
1312 }
1313
1314
1315 func (mt *MultiTenantServicePrincipalToken) PrimaryOAuthToken() string {
1316 return mt.PrimaryToken.OAuthToken()
1317 }
1318
1319
1320 func (mt *MultiTenantServicePrincipalToken) AuxiliaryOAuthTokens() []string {
1321 tokens := make([]string, len(mt.AuxiliaryTokens))
1322 for i := range mt.AuxiliaryTokens {
1323 tokens[i] = mt.AuxiliaryTokens[i].OAuthToken()
1324 }
1325 return tokens
1326 }
1327
1328
1329 func NewMultiTenantServicePrincipalToken(multiTenantCfg MultiTenantOAuthConfig, clientID string, secret string, resource string) (*MultiTenantServicePrincipalToken, error) {
1330 if err := validateStringParam(clientID, "clientID"); err != nil {
1331 return nil, err
1332 }
1333 if err := validateStringParam(secret, "secret"); err != nil {
1334 return nil, err
1335 }
1336 if err := validateStringParam(resource, "resource"); err != nil {
1337 return nil, err
1338 }
1339 auxTenants := multiTenantCfg.AuxiliaryTenants()
1340 m := MultiTenantServicePrincipalToken{
1341 AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1342 }
1343 primary, err := NewServicePrincipalToken(*multiTenantCfg.PrimaryTenant(), clientID, secret, resource)
1344 if err != nil {
1345 return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1346 }
1347 m.PrimaryToken = primary
1348 for i := range auxTenants {
1349 aux, err := NewServicePrincipalToken(*auxTenants[i], clientID, secret, resource)
1350 if err != nil {
1351 return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1352 }
1353 m.AuxiliaryTokens[i] = aux
1354 }
1355 return &m, nil
1356 }
1357
1358
1359 func NewMultiTenantServicePrincipalTokenFromCertificate(multiTenantCfg MultiTenantOAuthConfig, clientID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey, resource string) (*MultiTenantServicePrincipalToken, error) {
1360 if err := validateStringParam(clientID, "clientID"); err != nil {
1361 return nil, err
1362 }
1363 if err := validateStringParam(resource, "resource"); err != nil {
1364 return nil, err
1365 }
1366 if certificate == nil {
1367 return nil, fmt.Errorf("parameter 'certificate' cannot be nil")
1368 }
1369 if privateKey == nil {
1370 return nil, fmt.Errorf("parameter 'privateKey' cannot be nil")
1371 }
1372 auxTenants := multiTenantCfg.AuxiliaryTenants()
1373 m := MultiTenantServicePrincipalToken{
1374 AuxiliaryTokens: make([]*ServicePrincipalToken, len(auxTenants)),
1375 }
1376 primary, err := NewServicePrincipalTokenWithSecret(
1377 *multiTenantCfg.PrimaryTenant(),
1378 clientID,
1379 resource,
1380 &ServicePrincipalCertificateSecret{
1381 PrivateKey: privateKey,
1382 Certificate: certificate,
1383 },
1384 )
1385 if err != nil {
1386 return nil, fmt.Errorf("failed to create SPT for primary tenant: %v", err)
1387 }
1388 m.PrimaryToken = primary
1389 for i := range auxTenants {
1390 aux, err := NewServicePrincipalTokenWithSecret(
1391 *auxTenants[i],
1392 clientID,
1393 resource,
1394 &ServicePrincipalCertificateSecret{
1395 PrivateKey: privateKey,
1396 Certificate: certificate,
1397 },
1398 )
1399 if err != nil {
1400 return nil, fmt.Errorf("failed to create SPT for auxiliary tenant: %v", err)
1401 }
1402 m.AuxiliaryTokens[i] = aux
1403 }
1404 return &m, nil
1405 }
1406
1407
1408 func MSIAvailable(ctx context.Context, s Sender) bool {
1409 msiType, _, err := getMSIType()
1410
1411 if err != nil {
1412 return false
1413 }
1414
1415 if msiType != msiTypeIMDS {
1416 return true
1417 }
1418
1419 if s == nil {
1420 s = sender()
1421 }
1422
1423 resp, err := getMSIEndpoint(ctx, s)
1424
1425 if err == nil {
1426 resp.Body.Close()
1427 }
1428
1429 return err == nil
1430 }
1431
View as plain text