1 package oauth2
2
3 import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11 "strings"
12 "time"
13
14 "golang.org/x/oauth2/internal"
15 )
16
17
// OAuth 2.0 error codes that DeviceAccessToken inspects in token-endpoint
// error responses while polling.
const (
	// errAccessDenied is treated as terminal: polling stops and the error is
	// returned to the caller.
	errAccessDenied = "access_denied"

	// errAuthorizationPending means the user has not finished yet; polling
	// continues at the current interval.
	errAuthorizationPending = "authorization_pending"

	// errExpiredToken is treated as terminal: polling stops and the error is
	// returned to the caller.
	errExpiredToken = "expired_token"

	// errSlowDown makes the poller lengthen its interval by 5 seconds.
	errSlowDown = "slow_down"
)
24
25
26
// DeviceAuthResponse is the decoded reply of the device authorization
// endpoint, as returned by Config.DeviceAuth and consumed by
// Config.DeviceAccessToken. Field names follow RFC 8628 section 3.2.
type DeviceAuthResponse struct {
	// DeviceCode is sent back to the token endpoint (as "device_code")
	// while polling for a token.
	DeviceCode string `json:"device_code"`
	// UserCode is the code to present to the end user.
	UserCode string `json:"user_code"`
	// VerificationURI is where the end user goes to enter UserCode.
	VerificationURI string `json:"verification_uri"`
	// VerificationURIComplete is an optional verification URI supplied by
	// the server; empty when the server did not send one.
	VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
	// Expiry is the absolute time at which the codes stop being valid. On
	// the wire it is the relative "expires_in" value in seconds;
	// MarshalJSON and UnmarshalJSON convert between the two forms. The zero
	// value means no lifetime was reported.
	Expiry time.Time `json:"expires_in,omitempty"`
	// Interval is the polling interval in seconds; zero means the server did
	// not specify one.
	Interval int64 `json:"interval,omitempty"`
}

// MarshalJSON encodes d with its wire field names, replacing Expiry by
// "expires_in": the whole number of seconds from now until Expiry (truncated
// toward zero). The field is omitted when Expiry is zero or within one second
// of now, and is negative once Expiry is more than a second in the past.
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
	// Alias has the same fields but no methods, so marshaling it does not
	// recurse back into this method.
	type Alias DeviceAuthResponse
	var expiresIn int64
	if !d.Expiry.IsZero() {
		expiresIn = int64(time.Until(d.Expiry).Seconds())
	}
	// ExpiresIn sits at a shallower depth than Alias.Expiry, which is tagged
	// with the same "expires_in" name, so encoding/json emits ExpiresIn and
	// drops Expiry.
	return json.Marshal(&struct {
		ExpiresIn int64 `json:"expires_in,omitempty"`
		*Alias
	}{
		ExpiresIn: expiresIn,
		Alias:     (*Alias)(&d),
	})
}

// UnmarshalJSON decodes a device authorization response into d.
//
// A non-zero "expires_in" (seconds) becomes an absolute UTC Expiry measured
// from the moment of decoding; a missing or zero value leaves Expiry as is.
// If VerificationURI is still empty after decoding, the non-standard
// "verification_url" key is used instead. A JSON null is a no-op, following
// the encoding/json convention for Unmarshalers.
func (d *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
	type Alias DeviceAuthResponse
	aux := &struct {
		ExpiresIn int64 `json:"expires_in"`
		// Fallback for servers that spell the key "verification_url".
		VerificationURL string `json:"verification_url"`
		*Alias
	}{
		Alias: (*Alias)(d),
	}
	// Decode into aux itself, not &aux: given a pointer-to-pointer, a JSON
	// null would set aux to nil and the field reads below would panic.
	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}
	if aux.ExpiresIn != 0 {
		d.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
	}
	if d.VerificationURI == "" {
		d.VerificationURI = aux.VerificationURL
	}
	return nil
}
79
80
81
82 func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
83
84 v := url.Values{
85 "client_id": {c.ClientID},
86 }
87 if len(c.Scopes) > 0 {
88 v.Set("scope", strings.Join(c.Scopes, " "))
89 }
90 for _, opt := range opts {
91 opt.setValue(v)
92 }
93 return retrieveDeviceAuth(ctx, c, v)
94 }
95
96 func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
97 if c.Endpoint.DeviceAuthURL == "" {
98 return nil, errors.New("endpoint missing DeviceAuthURL")
99 }
100
101 req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
102 if err != nil {
103 return nil, err
104 }
105 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
106 req.Header.Set("Accept", "application/json")
107
108 t := time.Now()
109 r, err := internal.ContextClient(ctx).Do(req)
110 if err != nil {
111 return nil, err
112 }
113
114 body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
115 if err != nil {
116 return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
117 }
118 if code := r.StatusCode; code < 200 || code > 299 {
119 return nil, &RetrieveError{
120 Response: r,
121 Body: body,
122 }
123 }
124
125 da := &DeviceAuthResponse{}
126 err = json.Unmarshal(body, &da)
127 if err != nil {
128 return nil, fmt.Errorf("unmarshal %s", err)
129 }
130
131 if !da.Expiry.IsZero() {
132
133 da.Expiry = da.Expiry.Add(-time.Since(t))
134 }
135
136 return da, nil
137 }
138
139
140 func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
141 if !da.Expiry.IsZero() {
142 var cancel context.CancelFunc
143 ctx, cancel = context.WithDeadline(ctx, da.Expiry)
144 defer cancel()
145 }
146
147
148 v := url.Values{
149 "client_id": {c.ClientID},
150 "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
151 "device_code": {da.DeviceCode},
152 }
153 if len(c.Scopes) > 0 {
154 v.Set("scope", strings.Join(c.Scopes, " "))
155 }
156 for _, opt := range opts {
157 opt.setValue(v)
158 }
159
160
161
162 interval := da.Interval
163 if interval == 0 {
164 interval = 5
165 }
166
167 ticker := time.NewTicker(time.Duration(interval) * time.Second)
168 defer ticker.Stop()
169 for {
170 select {
171 case <-ctx.Done():
172 return nil, ctx.Err()
173 case <-ticker.C:
174 tok, err := retrieveToken(ctx, c, v)
175 if err == nil {
176 return tok, nil
177 }
178
179 e, ok := err.(*RetrieveError)
180 if !ok {
181 return nil, err
182 }
183 switch e.ErrorCode {
184 case errSlowDown:
185
186
187 interval += 5
188 ticker.Reset(time.Duration(interval) * time.Second)
189 case errAuthorizationPending:
190
191 case errAccessDenied, errExpiredToken:
192 fallthrough
193 default:
194 return tok, err
195 }
196 }
197 }
198 }
199
View as plain text