...
1
2
3
4
5
6
7
8
9 package jwt
10
11 import (
12 "context"
13 "encoding/json"
14 "fmt"
15 "io"
16 "io/ioutil"
17 "net/http"
18 "net/url"
19 "strings"
20 "time"
21
22 "golang.org/x/oauth2"
23 "golang.org/x/oauth2/internal"
24 "golang.org/x/oauth2/jws"
25 )
26
27 var (
28 defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
29 defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
30 )
31
32
33
// Config holds everything needed to build a signed JWT assertion and exchange
// it for a token at a token endpoint using the JWT bearer grant type.
type Config struct {
	// Email is used as the issuer (Iss) of the assertion's claim set.
	Email string

	// PrivateKey is the raw key material; it is parsed on every token fetch
	// and the resulting key signs the assertion.
	//
	// NOTE(review): the default header advertises RS256, so this is presumably
	// a PEM-encoded RSA private key; confirm against the key parser.
	PrivateKey []byte

	// PrivateKeyID is copied into the KeyID field of the JWS header. It may be
	// left empty.
	PrivateKeyID string

	// Subject, when non-empty, is written to both the Sub and Prn fields of
	// the claim set.
	Subject string

	// Scopes are joined with a single space and sent as the Scope claim.
	Scopes []string

	// TokenURL is the endpoint the assertion is POSTed to. It also serves as
	// the audience of the assertion unless Audience is set.
	TokenURL string

	// Expires, when positive, sets the assertion's Exp claim to the current
	// time plus this duration. Zero or negative values leave Exp unset here.
	Expires time.Duration

	// Audience, when non-empty, replaces TokenURL as the Aud claim of the
	// assertion.
	Audience string

	// PrivateClaims is passed through unchanged as the PrivateClaims of the
	// claim set.
	PrivateClaims map[string]interface{}

	// UseIDToken, when true, makes the returned token carry the response's
	// id_token as its AccessToken. Fetching fails if the response contains no
	// id_token.
	UseIDToken bool
}
78
79
80
81 func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
82 return oauth2.ReuseTokenSource(nil, jwtSource{ctx, c})
83 }
84
85
86
87
88
89
90 func (c *Config) Client(ctx context.Context) *http.Client {
91 return oauth2.NewClient(ctx, c.TokenSource(ctx))
92 }
93
94
95
96 type jwtSource struct {
97 ctx context.Context
98 conf *Config
99 }
100
101 func (js jwtSource) Token() (*oauth2.Token, error) {
102 pk, err := internal.ParseKey(js.conf.PrivateKey)
103 if err != nil {
104 return nil, err
105 }
106 hc := oauth2.NewClient(js.ctx, nil)
107 claimSet := &jws.ClaimSet{
108 Iss: js.conf.Email,
109 Scope: strings.Join(js.conf.Scopes, " "),
110 Aud: js.conf.TokenURL,
111 PrivateClaims: js.conf.PrivateClaims,
112 }
113 if subject := js.conf.Subject; subject != "" {
114 claimSet.Sub = subject
115
116
117 claimSet.Prn = subject
118 }
119 if t := js.conf.Expires; t > 0 {
120 claimSet.Exp = time.Now().Add(t).Unix()
121 }
122 if aud := js.conf.Audience; aud != "" {
123 claimSet.Aud = aud
124 }
125 h := *defaultHeader
126 h.KeyID = js.conf.PrivateKeyID
127 payload, err := jws.Encode(&h, claimSet, pk)
128 if err != nil {
129 return nil, err
130 }
131 v := url.Values{}
132 v.Set("grant_type", defaultGrantType)
133 v.Set("assertion", payload)
134 resp, err := hc.PostForm(js.conf.TokenURL, v)
135 if err != nil {
136 return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
137 }
138 defer resp.Body.Close()
139 body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
140 if err != nil {
141 return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
142 }
143 if c := resp.StatusCode; c < 200 || c > 299 {
144 return nil, &oauth2.RetrieveError{
145 Response: resp,
146 Body: body,
147 }
148 }
149
150 var tokenRes struct {
151 AccessToken string `json:"access_token"`
152 TokenType string `json:"token_type"`
153 IDToken string `json:"id_token"`
154 ExpiresIn int64 `json:"expires_in"`
155 }
156 if err := json.Unmarshal(body, &tokenRes); err != nil {
157 return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
158 }
159 token := &oauth2.Token{
160 AccessToken: tokenRes.AccessToken,
161 TokenType: tokenRes.TokenType,
162 }
163 raw := make(map[string]interface{})
164 json.Unmarshal(body, &raw)
165 token = token.WithExtra(raw)
166
167 if secs := tokenRes.ExpiresIn; secs > 0 {
168 token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
169 }
170 if v := tokenRes.IDToken; v != "" {
171
172 claimSet, err := jws.Decode(v)
173 if err != nil {
174 return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
175 }
176 token.Expiry = time.Unix(claimSet.Exp, 0)
177 }
178 if js.conf.UseIDToken {
179 if tokenRes.IDToken == "" {
180 return nil, fmt.Errorf("oauth2: response doesn't have JWT token")
181 }
182 token.AccessToken = tokenRes.IDToken
183 }
184 return token, nil
185 }
186
View as plain text