1
2
3 package devicecode
4
5 import (
6 "bytes"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "os/exec"
13 "strings"
14 "time"
15 )
16
// Constants for the OAuth device-code flow implemented in this file.
const (
	// oathGrantType is the grant_type value sent with every access-token
	// poll request.
	oathGrantType = "urn:ietf:params:oauth:grant-type:device_code"
	// deviceCodePath is appended to Client.BaseURL to build the endpoint that
	// issues device codes.
	deviceCodePath = "/login/device/code"
	// accessTokenPath is appended to Client.BaseURL to build the endpoint that
	// is polled for the access token.
	accessTokenPath = "/login/oauth/access_token"
)
22
23
// Client holds the settings needed to run the OAuth device-code flow.
type Client struct {
	// BaseURL is the prefix the endpoint paths are appended to, for example
	// "https://github.com".
	BaseURL string
	// ClientID is sent as client_id in both the device-code request and the
	// access-token poll requests.
	ClientID string
}
28
// deviceCodeRequest is the JSON body posted to the device-code endpoint.
type deviceCodeRequest struct {
	// ClientID is the OAuth application's client ID.
	ClientID string `json:"client_id"`
	// Scope is the scope string requested for the token.
	Scope string `json:"scope"`
}
33
// deviceCodeResponse is the JSON body decoded from the device-code endpoint.
type deviceCodeResponse struct {
	// DeviceCode is sent back to the server in each access-token poll.
	DeviceCode string `json:"device_code"`
	// UserCode is the code printed for the user to enter.
	UserCode string `json:"user_code"`
	// VerificationURI is the address printed for the user to visit.
	VerificationURI string `json:"verification_uri"`
	// ExpiresIn is decoded from expires_in; nothing in this file reads it.
	ExpiresIn int `json:"expires_in"`
	// Interval is the polling delay; the caller multiplies it by time.Second.
	Interval int `json:"interval"`
}
41
// accessTokenRequest is the JSON body posted to the access-token endpoint on
// every poll.
type accessTokenRequest struct {
	// ClientID is the OAuth application's client ID.
	ClientID string `json:"client_id"`
	// DeviceCode is the device_code value returned by the device-code endpoint.
	DeviceCode string `json:"device_code"`
	// GrantType is set to oathGrantType by the polling code.
	GrantType string `json:"grant_type"`
}
47
48
49
// accessTokenResponse is the JSON body decoded from the access-token
// endpoint. The polling code treats an empty Error as success.
type accessTokenResponse struct {
	// AccessToken is the token handed back to the caller on success.
	AccessToken string `json:"access_token"`
	// TokenType is decoded from token_type; nothing in this file reads it.
	TokenType string `json:"token_type"`
	// Scope is decoded from scope; nothing in this file reads it.
	Scope string `json:"scope"`
	// Error is the error code; the polling code recognises
	// "authorization_pending", "slow_down" and "access_denied".
	Error string `json:"error"`
	// ErrorDescription is the message used for error codes the polling code
	// does not recognise.
	ErrorDescription string `json:"error_description"`
	// Interval is the new polling delay applied on "slow_down"; it is
	// multiplied by time.Second.
	Interval int `json:"interval"`
}
58
59 func NewGitHubOauthClient(oauthClientID string) *Client {
60 if oauthClientID == "" {
61 return nil
62 }
63 return &Client{
64 BaseURL: "https://github.com",
65 ClientID: oauthClientID,
66 }
67 }
68
69
70
71
72
73
74 func (g *Client) DeviceCodeAuthToken() (string, error) {
75 token, found := keychainFindToken()
76 if found {
77 return token, nil
78 }
79
80
81 authResp, err := g.getDeviceCode()
82 if err != nil {
83 return "", err
84 }
85
86
87 fmt.Printf(
88 "Your device code is: %s Enter this code at: %s\n",
89 authResp.UserCode,
90 authResp.VerificationURI,
91 )
92
93
94 interval := time.Duration(authResp.Interval) * time.Second
95 tokenResp, err := g.waitForAccessToken(authResp.DeviceCode, interval)
96 if err != nil {
97 return "", err
98 }
99
100 token = tokenResp.AccessToken
101 if err = keychainAddToken(token); err != nil {
102 fmt.Printf(
103 "warning: failed to store token in local keychain, continuing. error: %v\n",
104 err.Error(),
105 )
106 }
107
108 return token, nil
109 }
110
111 func (g *Client) accessTokenURL() string {
112 return fmt.Sprintf("%s%s", g.BaseURL, accessTokenPath)
113 }
114
115 func (g *Client) deviceCodeURL() string {
116 return fmt.Sprintf("%s%s", g.BaseURL, deviceCodePath)
117 }
118
// setGitHubJSONHeaders marks r as sending a JSON body and asks for the
// "application/vnd.github.v3+json" media type in the response.
//
// Header.Set is used instead of Header.Add so that calling this on a request
// that already carries these headers replaces the values rather than
// appending duplicates.
func setGitHubJSONHeaders(r *http.Request) {
	r.Header.Set("Accept", "application/vnd.github.v3+json")
	r.Header.Set("Content-Type", "application/json")
}
123
124 func (g *Client) getDeviceCode() (*deviceCodeResponse, error) {
125
126 authReqRaw, err := json.Marshal(&deviceCodeRequest{ClientID: g.ClientID, Scope: "repo delete_repo"})
127 if err != nil {
128 return nil, err
129 }
130 authReqBody := bytes.NewReader(authReqRaw)
131 authRequest, err := http.NewRequest("POST", g.deviceCodeURL(), authReqBody)
132 if err != nil {
133 return nil, err
134 }
135 setGitHubJSONHeaders(authRequest)
136
137
138 authResponse, err := http.DefaultClient.Do(authRequest)
139 if err != nil {
140 return nil, err
141 }
142
143
144 authRespRaw, err := io.ReadAll(authResponse.Body)
145 if err != nil {
146 return nil, err
147 }
148 authResp := &deviceCodeResponse{}
149 err = json.Unmarshal(authRespRaw, authResp)
150 if err != nil {
151 return nil, err
152 }
153
154 return authResp, nil
155 }
156
157 func (g *Client) waitForAccessToken(deviceCode string, interval time.Duration) (*accessTokenResponse, error) {
158
159
160 time.Sleep(interval)
161
162 tokenReqRaw, err := json.Marshal(&accessTokenRequest{
163 ClientID: g.ClientID,
164 DeviceCode: deviceCode,
165 GrantType: oathGrantType,
166 })
167 if err != nil {
168 return nil, err
169 }
170
171 for {
172 tokenReqBody := bytes.NewReader(tokenReqRaw)
173 tokenRequest, err := http.NewRequest(
174 "POST",
175 g.accessTokenURL(),
176 tokenReqBody,
177 )
178 if err != nil {
179 return nil, err
180 }
181 setGitHubJSONHeaders(tokenRequest)
182 tokenResRaw, err := http.DefaultClient.Do(tokenRequest)
183 if err != nil {
184 return nil, err
185 }
186
187 tokenResBody, err := io.ReadAll(tokenResRaw.Body)
188 if err != nil {
189 return nil, err
190 }
191 if err = tokenResRaw.Body.Close(); err != nil {
192 return nil, err
193 }
194 tokenRes := &accessTokenResponse{}
195 if err = json.Unmarshal(tokenResBody, tokenRes); err != nil {
196 return nil, err
197 }
198 if tokenRes.Error == "" {
199 return tokenRes, nil
200 }
201 switch tokenRes.Error {
202 case "authorization_pending":
203
204 case "access_denied":
205 return nil, errors.New("user declined auth request")
206 case "slow_down":
207 interval = time.Duration(tokenRes.Interval) * time.Second
208 default:
209 return nil, errors.New(tokenRes.ErrorDescription)
210 }
211
212 time.Sleep(interval)
213 }
214 }
215
// keychainAddToken stores token as an internet password in the keychain by
// running the `security` command-line tool, under account "edge-infra" and
// server "dev-edge-ncr-oauth".
//
// The -U flag makes the tool update an existing item instead of failing when
// one is already present, so a re-login can replace a stale entry.
//
// NOTE(review): the token is passed as a command-line argument and is
// therefore visible in the process list while the command runs.
func keychainAddToken(token string) error {
	args := []string{
		"add-internet-password",
		"-a", "edge-infra",
		"-s", "dev-edge-ncr-oauth",
		"-U",
		"-w", token,
	}
	if err := exec.Command("security", args...).Run(); err != nil {
		return fmt.Errorf("running security add-internet-password: %w", err)
	}
	return nil
}
221
// keychainFindToken looks up the stored token by running the `security`
// command-line tool for account "edge-infra" and server "dev-edge-ncr-oauth".
//
// It returns the whitespace-trimmed command output and true when the command
// succeeds and prints a non-empty value. It returns "" and false when the
// command fails for any reason or prints nothing, so callers never receive an
// empty token flagged as found.
func keychainFindToken() (string, bool) {
	args := []string{"find-internet-password", "-a", "edge-infra", "-s", "dev-edge-ncr-oauth", "-w"}
	out, err := exec.Command("security", args...).Output()
	if err != nil {
		return "", false
	}
	token := strings.TrimSpace(string(out))
	return token, token != ""
}
233
View as plain text