...

Source file src/github.com/Azure/go-autorest/autorest/adal/devicetoken.go

Documentation: github.com/Azure/go-autorest/autorest/adal

     1  package adal
     2  
     3  // Copyright 2017 Microsoft Corporation
     4  //
     5  //  Licensed under the Apache License, Version 2.0 (the "License");
     6  //  you may not use this file except in compliance with the License.
     7  //  You may obtain a copy of the License at
     8  //
     9  //      http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  //  Unless required by applicable law or agreed to in writing, software
    12  //  distributed under the License is distributed on an "AS IS" BASIS,
    13  //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  //  See the License for the specific language governing permissions and
    15  //  limitations under the License.
    16  
    17  /*
    18    This file is largely based on rjw57/oauth2device's code, with the following differences:
    19     * scope -> resource, and only allow a single one
    20     * receive "Message" in the DeviceCode struct and show it to users as the prompt
    21     * azure-xplat-cli has the following behavior that this emulates:
    22       - does not send client_secret during the token exchange
    23       - sends resource again in the token exchange request
    24  */
    25  
    26  import (
    27  	"context"
    28  	"encoding/json"
    29  	"fmt"
    30  	"io/ioutil"
    31  	"net/http"
    32  	"net/url"
    33  	"strings"
    34  	"time"
    35  )
    36  
    37  const (
    38  	logPrefix = "autorest/adal/devicetoken:"
    39  )
    40  
    41  var (
    42  	// ErrDeviceGeneric represents an unknown error from the token endpoint when using device flow
    43  	ErrDeviceGeneric = fmt.Errorf("%s Error while retrieving OAuth token: Unknown Error", logPrefix)
    44  
    45  	// ErrDeviceAccessDenied represents an access denied error from the token endpoint when using device flow
    46  	ErrDeviceAccessDenied = fmt.Errorf("%s Error while retrieving OAuth token: Access Denied", logPrefix)
    47  
    48  	// ErrDeviceAuthorizationPending represents the server waiting on the user to complete the device flow
    49  	ErrDeviceAuthorizationPending = fmt.Errorf("%s Error while retrieving OAuth token: Authorization Pending", logPrefix)
    50  
    51  	// ErrDeviceCodeExpired represents the server timing out and expiring the code during device flow
    52  	ErrDeviceCodeExpired = fmt.Errorf("%s Error while retrieving OAuth token: Code Expired", logPrefix)
    53  
    54  	// ErrDeviceSlowDown represents the service telling us we're polling too often during device flow
    55  	ErrDeviceSlowDown = fmt.Errorf("%s Error while retrieving OAuth token: Slow Down", logPrefix)
    56  
    57  	// ErrDeviceCodeEmpty represents an empty device code from the device endpoint while using device flow
    58  	ErrDeviceCodeEmpty = fmt.Errorf("%s Error while retrieving device code: Device Code Empty", logPrefix)
    59  
    60  	// ErrOAuthTokenEmpty represents an empty OAuth token from the token endpoint when using device flow
    61  	ErrOAuthTokenEmpty = fmt.Errorf("%s Error while retrieving OAuth token: Token Empty", logPrefix)
    62  
    63  	errCodeSendingFails   = "Error occurred while sending request for Device Authorization Code"
    64  	errCodeHandlingFails  = "Error occurred while handling response from the Device Endpoint"
    65  	errTokenSendingFails  = "Error occurred while sending request with device code for a token"
    66  	errTokenHandlingFails = "Error occurred while handling response from the Token Endpoint (during device flow)"
    67  	errStatusNotOK        = "Error HTTP status != 200"
    68  )
    69  
    70  // DeviceCode is the object returned by the device auth endpoint
    71  // It contains information to instruct the user to complete the auth flow
    72  type DeviceCode struct {
    73  	DeviceCode      *string `json:"device_code,omitempty"`
    74  	UserCode        *string `json:"user_code,omitempty"`
    75  	VerificationURL *string `json:"verification_url,omitempty"`
    76  	ExpiresIn       *int64  `json:"expires_in,string,omitempty"`
    77  	Interval        *int64  `json:"interval,string,omitempty"`
    78  
    79  	Message     *string `json:"message"` // Azure specific
    80  	Resource    string  // store the following, stored when initiating, used when exchanging
    81  	OAuthConfig OAuthConfig
    82  	ClientID    string
    83  }
    84  
    85  // TokenError is the object returned by the token exchange endpoint
    86  // when something is amiss
    87  type TokenError struct {
    88  	Error            *string `json:"error,omitempty"`
    89  	ErrorCodes       []int   `json:"error_codes,omitempty"`
    90  	ErrorDescription *string `json:"error_description,omitempty"`
    91  	Timestamp        *string `json:"timestamp,omitempty"`
    92  	TraceID          *string `json:"trace_id,omitempty"`
    93  }
    94  
    95  // DeviceToken is the object return by the token exchange endpoint
    96  // It can either look like a Token or an ErrorToken, so put both here
    97  // and check for presence of "Error" to know if we are in error state
    98  type deviceToken struct {
    99  	Token
   100  	TokenError
   101  }
   102  
   103  // InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
   104  // that can be used with CheckForUserCompletion or WaitForUserCompletion.
   105  // Deprecated: use InitiateDeviceAuthWithContext() instead.
   106  func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
   107  	return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource)
   108  }
   109  
   110  // InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode
   111  // that can be used with CheckForUserCompletion or WaitForUserCompletion.
   112  func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
   113  	v := url.Values{
   114  		"client_id": []string{clientID},
   115  		"resource":  []string{resource},
   116  	}
   117  
   118  	s := v.Encode()
   119  	body := ioutil.NopCloser(strings.NewReader(s))
   120  
   121  	req, err := http.NewRequest(http.MethodPost, oauthConfig.DeviceCodeEndpoint.String(), body)
   122  	if err != nil {
   123  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
   124  	}
   125  
   126  	req.ContentLength = int64(len(s))
   127  	req.Header.Set(contentType, mimeTypeFormPost)
   128  	resp, err := sender.Do(req.WithContext(ctx))
   129  	if err != nil {
   130  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
   131  	}
   132  	defer resp.Body.Close()
   133  
   134  	rb, err := ioutil.ReadAll(resp.Body)
   135  	if err != nil {
   136  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
   137  	}
   138  
   139  	if resp.StatusCode != http.StatusOK {
   140  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, errStatusNotOK)
   141  	}
   142  
   143  	if len(strings.Trim(string(rb), " ")) == 0 {
   144  		return nil, ErrDeviceCodeEmpty
   145  	}
   146  
   147  	var code DeviceCode
   148  	err = json.Unmarshal(rb, &code)
   149  	if err != nil {
   150  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
   151  	}
   152  
   153  	code.ClientID = clientID
   154  	code.Resource = resource
   155  	code.OAuthConfig = oauthConfig
   156  
   157  	return &code, nil
   158  }
   159  
   160  // CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
   161  // to see if the device flow has: been completed, timed out, or otherwise failed
   162  // Deprecated: use CheckForUserCompletionWithContext() instead.
   163  func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
   164  	return CheckForUserCompletionWithContext(context.Background(), sender, code)
   165  }
   166  
   167  // CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint
   168  // to see if the device flow has: been completed, timed out, or otherwise failed
   169  func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
   170  	v := url.Values{
   171  		"client_id":  []string{code.ClientID},
   172  		"code":       []string{*code.DeviceCode},
   173  		"grant_type": []string{OAuthGrantTypeDeviceCode},
   174  		"resource":   []string{code.Resource},
   175  	}
   176  
   177  	s := v.Encode()
   178  	body := ioutil.NopCloser(strings.NewReader(s))
   179  
   180  	req, err := http.NewRequest(http.MethodPost, code.OAuthConfig.TokenEndpoint.String(), body)
   181  	if err != nil {
   182  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
   183  	}
   184  
   185  	req.ContentLength = int64(len(s))
   186  	req.Header.Set(contentType, mimeTypeFormPost)
   187  	resp, err := sender.Do(req.WithContext(ctx))
   188  	if err != nil {
   189  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
   190  	}
   191  	defer resp.Body.Close()
   192  
   193  	rb, err := ioutil.ReadAll(resp.Body)
   194  	if err != nil {
   195  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
   196  	}
   197  
   198  	if resp.StatusCode != http.StatusOK && len(strings.Trim(string(rb), " ")) == 0 {
   199  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, errStatusNotOK)
   200  	}
   201  	if len(strings.Trim(string(rb), " ")) == 0 {
   202  		return nil, ErrOAuthTokenEmpty
   203  	}
   204  
   205  	var token deviceToken
   206  	err = json.Unmarshal(rb, &token)
   207  	if err != nil {
   208  		return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
   209  	}
   210  
   211  	if token.Error == nil {
   212  		return &token.Token, nil
   213  	}
   214  
   215  	switch *token.Error {
   216  	case "authorization_pending":
   217  		return nil, ErrDeviceAuthorizationPending
   218  	case "slow_down":
   219  		return nil, ErrDeviceSlowDown
   220  	case "access_denied":
   221  		return nil, ErrDeviceAccessDenied
   222  	case "code_expired":
   223  		return nil, ErrDeviceCodeExpired
   224  	default:
   225  		// return a more meaningful error message if available
   226  		if token.ErrorDescription != nil {
   227  			return nil, fmt.Errorf("%s %s: %s", logPrefix, *token.Error, *token.ErrorDescription)
   228  		}
   229  		return nil, ErrDeviceGeneric
   230  	}
   231  }
   232  
   233  // WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
   234  // This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
   235  // Deprecated: use WaitForUserCompletionWithContext() instead.
   236  func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
   237  	return WaitForUserCompletionWithContext(context.Background(), sender, code)
   238  }
   239  
   240  // WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error
   241  // state occurs.  This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
   242  func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
   243  	intervalDuration := time.Duration(*code.Interval) * time.Second
   244  	waitDuration := intervalDuration
   245  
   246  	for {
   247  		token, err := CheckForUserCompletionWithContext(ctx, sender, code)
   248  
   249  		if err == nil {
   250  			return token, nil
   251  		}
   252  
   253  		switch err {
   254  		case ErrDeviceSlowDown:
   255  			waitDuration += waitDuration
   256  		case ErrDeviceAuthorizationPending:
   257  			// noop
   258  		default: // everything else is "fatal" to us
   259  			return nil, err
   260  		}
   261  
   262  		if waitDuration > (intervalDuration * 3) {
   263  			return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
   264  		}
   265  
   266  		select {
   267  		case <-time.After(waitDuration):
   268  			// noop
   269  		case <-ctx.Done():
   270  			return nil, ctx.Err()
   271  		}
   272  	}
   273  }
   274  

View as plain text