...

Source file src/edge-infra.dev/pkg/f8n/devinfra/github/oauth/devicecode/devicecode.go

Documentation: edge-infra.dev/pkg/f8n/devinfra/github/oauth/devicecode

// Package devicecode implements functionality for getting OAuth tokens to authenticate
// with GitHub on behalf of a user.
     3  package devicecode
     4  
     5  import (
     6  	"bytes"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"os/exec"
    13  	"strings"
    14  	"time"
    15  )
    16  
    17  const (
    18  	oathGrantType   = "urn:ietf:params:oauth:grant-type:device_code"
    19  	deviceCodePath  = "/login/device/code"
    20  	accessTokenPath = "/login/oauth/access_token" // #nosec G101 - not an access token
    21  )
    22  
    23  // Client holds basic information to get an oauth token from github
    24  type Client struct {
    25  	BaseURL  string
    26  	ClientID string
    27  }
    28  
    29  type deviceCodeRequest struct {
    30  	ClientID string `json:"client_id"`
    31  	Scope    string `json:"scope"`
    32  }
    33  
    34  type deviceCodeResponse struct {
    35  	DeviceCode      string `json:"device_code"`
    36  	UserCode        string `json:"user_code"`
    37  	VerificationURI string `json:"verification_uri"`
    38  	ExpiresIn       int    `json:"expires_in"`
    39  	Interval        int    `json:"interval"`
    40  }
    41  
    42  type accessTokenRequest struct {
    43  	ClientID   string `json:"client_id"`
    44  	DeviceCode string `json:"device_code"`
    45  	GrantType  string `json:"grant_type"`
    46  }
    47  
    48  // a superset of success/error responses. if 'error' is not empty, only error fields
    49  // will be present, and vice versa
    50  type accessTokenResponse struct {
    51  	AccessToken      string `json:"access_token"`
    52  	TokenType        string `json:"token_type"`
    53  	Scope            string `json:"scope"`
    54  	Error            string `json:"error"`
    55  	ErrorDescription string `json:"error_description"`
    56  	Interval         int    `json:"interval"`
    57  }
    58  
    59  func NewGitHubOauthClient(oauthClientID string) *Client {
    60  	if oauthClientID == "" {
    61  		return nil
    62  	}
    63  	return &Client{
    64  		BaseURL:  "https://github.com",
    65  		ClientID: oauthClientID,
    66  	}
    67  }
    68  
    69  // DeviceCodeAuthToken performs the device based oauth flow for obtaining a user access token
    70  // returns a non-empty access token string and nil err, or empty string and an error. May
    71  // cache or retrieve token from macos keychain, if possible. If cached token is not found,
    72  // this function blocks on user input.
    73  // See: https://docs.github.com/en/developers/apps/authorizing-oauth-apps#device-flow
    74  func (g *Client) DeviceCodeAuthToken() (string, error) {
    75  	token, found := keychainFindToken()
    76  	if found {
    77  		return token, nil
    78  	}
    79  
    80  	// Step 1: App requests the device and user verification codes from GitHub
    81  	authResp, err := g.getDeviceCode()
    82  	if err != nil {
    83  		return "", err
    84  	}
    85  
    86  	// Step 2: Prompt the user to enter the user code in a browser
    87  	fmt.Printf(
    88  		"Your device code is: %s Enter this code at: %s\n",
    89  		authResp.UserCode,
    90  		authResp.VerificationURI,
    91  	)
    92  
    93  	// Step 3: App polls GitHub to check if the user authorized the device
    94  	interval := time.Duration(authResp.Interval) * time.Second
    95  	tokenResp, err := g.waitForAccessToken(authResp.DeviceCode, interval)
    96  	if err != nil {
    97  		return "", err
    98  	}
    99  
   100  	token = tokenResp.AccessToken
   101  	if err = keychainAddToken(token); err != nil {
   102  		fmt.Printf(
   103  			"warning: failed to store token in local keychain, continuing. error: %v\n",
   104  			err.Error(),
   105  		)
   106  	}
   107  
   108  	return token, nil
   109  }
   110  
   111  func (g *Client) accessTokenURL() string {
   112  	return fmt.Sprintf("%s%s", g.BaseURL, accessTokenPath)
   113  }
   114  
   115  func (g *Client) deviceCodeURL() string {
   116  	return fmt.Sprintf("%s%s", g.BaseURL, deviceCodePath)
   117  }
   118  
   119  func setGitHubJSONHeaders(r *http.Request) {
   120  	r.Header.Add("Accept", "application/vnd.github.v3+json")
   121  	r.Header.Add("Content-Type", "application/json")
   122  }
   123  
   124  func (g *Client) getDeviceCode() (*deviceCodeResponse, error) {
   125  	// create request
   126  	authReqRaw, err := json.Marshal(&deviceCodeRequest{ClientID: g.ClientID, Scope: "repo delete_repo"})
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	authReqBody := bytes.NewReader(authReqRaw)
   131  	authRequest, err := http.NewRequest("POST", g.deviceCodeURL(), authReqBody)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	setGitHubJSONHeaders(authRequest)
   136  
   137  	// send request
   138  	authResponse, err := http.DefaultClient.Do(authRequest)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	// parse response
   144  	authRespRaw, err := io.ReadAll(authResponse.Body)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	authResp := &deviceCodeResponse{}
   149  	err = json.Unmarshal(authRespRaw, authResp)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	return authResp, nil
   155  }
   156  
   157  func (g *Client) waitForAccessToken(deviceCode string, interval time.Duration) (*accessTokenResponse, error) {
   158  	// wait a small period, then poll verification uri until user has entered code
   159  	// but no longer than the expiration in the auth response
   160  	time.Sleep(interval)
   161  
   162  	tokenReqRaw, err := json.Marshal(&accessTokenRequest{
   163  		ClientID:   g.ClientID,
   164  		DeviceCode: deviceCode,
   165  		GrantType:  oathGrantType,
   166  	})
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	for {
   172  		tokenReqBody := bytes.NewReader(tokenReqRaw)
   173  		tokenRequest, err := http.NewRequest(
   174  			"POST",
   175  			g.accessTokenURL(),
   176  			tokenReqBody,
   177  		)
   178  		if err != nil {
   179  			return nil, err
   180  		}
   181  		setGitHubJSONHeaders(tokenRequest)
   182  		tokenResRaw, err := http.DefaultClient.Do(tokenRequest)
   183  		if err != nil {
   184  			return nil, err
   185  		}
   186  
   187  		tokenResBody, err := io.ReadAll(tokenResRaw.Body)
   188  		if err != nil {
   189  			return nil, err
   190  		}
   191  		if err = tokenResRaw.Body.Close(); err != nil {
   192  			return nil, err
   193  		}
   194  		tokenRes := &accessTokenResponse{}
   195  		if err = json.Unmarshal(tokenResBody, tokenRes); err != nil {
   196  			return nil, err
   197  		}
   198  		if tokenRes.Error == "" {
   199  			return tokenRes, nil
   200  		}
   201  		switch tokenRes.Error {
   202  		case "authorization_pending":
   203  			// ok. just wait and try again, but avoid default error behavior
   204  		case "access_denied":
   205  			return nil, errors.New("user declined auth request")
   206  		case "slow_down":
   207  			interval = time.Duration(tokenRes.Interval) * time.Second
   208  		default:
   209  			return nil, errors.New(tokenRes.ErrorDescription)
   210  		}
   211  
   212  		time.Sleep(interval)
   213  	}
   214  }
   215  
   216  func keychainAddToken(token string) error {
   217  	command := "security"
   218  	args := []string{"add-internet-password", "-a", "edge-infra", "-s", "dev-edge-ncr-oauth", "-w", token}
   219  	return exec.Command(command, args...).Run()
   220  }
   221  
   222  func keychainFindToken() (string, bool) {
   223  	command := "security"
   224  	args := []string{"find-internet-password", "-a", "edge-infra", "-s", "dev-edge-ncr-oauth", "-w"}
   225  	cmd := exec.Command(command, args...)
   226  	out, err := cmd.Output()
   227  	if err != nil {
   228  		return "", false
   229  	}
   230  	token := strings.TrimSpace(string(out))
   231  	return token, true
   232  }
   233  

View as plain text