...

Source file src/golang.org/x/oauth2/deviceauth.go

Documentation: golang.org/x/oauth2

     1  package oauth2
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"strings"
    12  	"time"
    13  
    14  	"golang.org/x/oauth2/internal"
    15  )
    16  
    17  // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
    18  const (
    19  	errAuthorizationPending = "authorization_pending"
    20  	errSlowDown             = "slow_down"
    21  	errAccessDenied         = "access_denied"
    22  	errExpiredToken         = "expired_token"
    23  )
    24  
    25  // DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
    26  // https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
    27  type DeviceAuthResponse struct {
    28  	// DeviceCode
    29  	DeviceCode string `json:"device_code"`
    30  	// UserCode is the code the user should enter at the verification uri
    31  	UserCode string `json:"user_code"`
    32  	// VerificationURI is where user should enter the user code
    33  	VerificationURI string `json:"verification_uri"`
    34  	// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
    35  	VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
    36  	// Expiry is when the device code and user code expire
    37  	Expiry time.Time `json:"expires_in,omitempty"`
    38  	// Interval is the duration in seconds that Poll should wait between requests
    39  	Interval int64 `json:"interval,omitempty"`
    40  }
    41  
    42  func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
    43  	type Alias DeviceAuthResponse
    44  	var expiresIn int64
    45  	if !d.Expiry.IsZero() {
    46  		expiresIn = int64(time.Until(d.Expiry).Seconds())
    47  	}
    48  	return json.Marshal(&struct {
    49  		ExpiresIn int64 `json:"expires_in,omitempty"`
    50  		*Alias
    51  	}{
    52  		ExpiresIn: expiresIn,
    53  		Alias:     (*Alias)(&d),
    54  	})
    55  
    56  }
    57  
    58  func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
    59  	type Alias DeviceAuthResponse
    60  	aux := &struct {
    61  		ExpiresIn int64 `json:"expires_in"`
    62  		// workaround misspelling of verification_uri
    63  		VerificationURL string `json:"verification_url"`
    64  		*Alias
    65  	}{
    66  		Alias: (*Alias)(c),
    67  	}
    68  	if err := json.Unmarshal(data, &aux); err != nil {
    69  		return err
    70  	}
    71  	if aux.ExpiresIn != 0 {
    72  		c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
    73  	}
    74  	if c.VerificationURI == "" {
    75  		c.VerificationURI = aux.VerificationURL
    76  	}
    77  	return nil
    78  }
    79  
    80  // DeviceAuth returns a device auth struct which contains a device code
    81  // and authorization information provided for users to enter on another device.
    82  func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
    83  	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
    84  	v := url.Values{
    85  		"client_id": {c.ClientID},
    86  	}
    87  	if len(c.Scopes) > 0 {
    88  		v.Set("scope", strings.Join(c.Scopes, " "))
    89  	}
    90  	for _, opt := range opts {
    91  		opt.setValue(v)
    92  	}
    93  	return retrieveDeviceAuth(ctx, c, v)
    94  }
    95  
    96  func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
    97  	if c.Endpoint.DeviceAuthURL == "" {
    98  		return nil, errors.New("endpoint missing DeviceAuthURL")
    99  	}
   100  
   101  	req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   106  	req.Header.Set("Accept", "application/json")
   107  
   108  	t := time.Now()
   109  	r, err := internal.ContextClient(ctx).Do(req)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
   115  	if err != nil {
   116  		return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
   117  	}
   118  	if code := r.StatusCode; code < 200 || code > 299 {
   119  		return nil, &RetrieveError{
   120  			Response: r,
   121  			Body:     body,
   122  		}
   123  	}
   124  
   125  	da := &DeviceAuthResponse{}
   126  	err = json.Unmarshal(body, &da)
   127  	if err != nil {
   128  		return nil, fmt.Errorf("unmarshal %s", err)
   129  	}
   130  
   131  	if !da.Expiry.IsZero() {
   132  		// Make a small adjustment to account for time taken by the request
   133  		da.Expiry = da.Expiry.Add(-time.Since(t))
   134  	}
   135  
   136  	return da, nil
   137  }
   138  
   139  // DeviceAccessToken polls the server to exchange a device code for a token.
   140  func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
   141  	if !da.Expiry.IsZero() {
   142  		var cancel context.CancelFunc
   143  		ctx, cancel = context.WithDeadline(ctx, da.Expiry)
   144  		defer cancel()
   145  	}
   146  
   147  	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
   148  	v := url.Values{
   149  		"client_id":   {c.ClientID},
   150  		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
   151  		"device_code": {da.DeviceCode},
   152  	}
   153  	if len(c.Scopes) > 0 {
   154  		v.Set("scope", strings.Join(c.Scopes, " "))
   155  	}
   156  	for _, opt := range opts {
   157  		opt.setValue(v)
   158  	}
   159  
   160  	// "If no value is provided, clients MUST use 5 as the default."
   161  	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
   162  	interval := da.Interval
   163  	if interval == 0 {
   164  		interval = 5
   165  	}
   166  
   167  	ticker := time.NewTicker(time.Duration(interval) * time.Second)
   168  	defer ticker.Stop()
   169  	for {
   170  		select {
   171  		case <-ctx.Done():
   172  			return nil, ctx.Err()
   173  		case <-ticker.C:
   174  			tok, err := retrieveToken(ctx, c, v)
   175  			if err == nil {
   176  				return tok, nil
   177  			}
   178  
   179  			e, ok := err.(*RetrieveError)
   180  			if !ok {
   181  				return nil, err
   182  			}
   183  			switch e.ErrorCode {
   184  			case errSlowDown:
   185  				// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
   186  				// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
   187  				interval += 5
   188  				ticker.Reset(time.Duration(interval) * time.Second)
   189  			case errAuthorizationPending:
   190  				// Do nothing.
   191  			case errAccessDenied, errExpiredToken:
   192  				fallthrough
   193  			default:
   194  				return tok, err
   195  			}
   196  		}
   197  	}
   198  }
   199  

View as plain text