1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package githubapp
16
17 import (
18 "context"
19 "fmt"
20 "net/http"
21 "net/url"
22 "regexp"
23 "strings"
24 "time"
25
26 "github.com/bradleyfalzon/ghinstallation/v2"
27 "github.com/google/go-github/v47/github"
28 "github.com/gregjones/httpcache"
29 "github.com/pkg/errors"
30 "github.com/shurcooL/githubv4"
31 "golang.org/x/oauth2"
32 )
33
// ClientCreator builds GitHub API clients for the different ways an
// application can authenticate: as the app itself, as one installation of the
// app, or with an OAuth2 token. Each flavor is available for both the v3
// (REST, *github.Client) and v4 (GraphQL, *githubv4.Client) APIs.
type ClientCreator interface {
	// NewAppClient returns a v3 client that authenticates as the GitHub App
	// itself, using the creator's integration ID and private key.
	//
	// An error is returned if the client or its authenticating transport
	// cannot be constructed.
	NewAppClient() (*github.Client, error)

	// NewAppV4Client returns a v4 client that authenticates as the GitHub App
	// itself, using the creator's integration ID and private key.
	NewAppV4Client() (*githubv4.Client, error)

	// NewInstallationClient returns a v3 client that authenticates as the
	// installation identified by installationID, using the creator's
	// integration ID and private key.
	//
	// An error is returned if the client or its authenticating transport
	// cannot be constructed.
	NewInstallationClient(installationID int64) (*github.Client, error)

	// NewInstallationV4Client returns a v4 client that authenticates as the
	// installation identified by installationID.
	NewInstallationV4Client(installationID int64) (*githubv4.Client, error)

	// NewTokenSourceClient returns a v3 client that authenticates with tokens
	// obtained from the given OAuth2 token source.
	NewTokenSourceClient(ts oauth2.TokenSource) (*github.Client, error)

	// NewTokenSourceV4Client returns a v4 client that authenticates with
	// tokens obtained from the given OAuth2 token source.
	NewTokenSourceV4Client(ts oauth2.TokenSource) (*githubv4.Client, error)

	// NewTokenClient returns a v3 client that authenticates with the given
	// static OAuth2 access token.
	NewTokenClient(token string) (*github.Client, error)

	// NewTokenV4Client returns a v4 client that authenticates with the given
	// static OAuth2 access token.
	NewTokenV4Client(token string) (*githubv4.Client, error)
}
95
// maxAgeRegex matches a "max-age=<digits>" directive inside a Cache-Control
// header value.
var maxAgeRegex = regexp.MustCompile(`max-age=\d+`)

// key is a private type for context keys defined in this package, so they
// cannot collide with keys defined elsewhere.
type key string

// installationKey is the context key under which the installation ID of a
// request is stored.
const installationKey key = "installationID"
103
104
105
106 func NewClientCreator(v3BaseURL, v4BaseURL string, integrationID int64, privKeyBytes []byte, opts ...ClientOption) ClientCreator {
107 cc := &clientCreator{
108 v3BaseURL: v3BaseURL,
109 v4BaseURL: v4BaseURL,
110 integrationID: integrationID,
111 privKeyBytes: privKeyBytes,
112 }
113
114 for _, opt := range opts {
115 opt(cc)
116 }
117
118 if !strings.HasSuffix(cc.v3BaseURL, "/") {
119 cc.v3BaseURL += "/"
120 }
121
122
123 cc.v4BaseURL = strings.TrimSuffix(cc.v4BaseURL, "/")
124
125 return cc
126 }
127
128 type clientCreator struct {
129 v3BaseURL string
130 v4BaseURL string
131 integrationID int64
132 privKeyBytes []byte
133 userAgent string
134 middleware []ClientMiddleware
135 cacheFunc func() httpcache.Cache
136 alwaysValidate bool
137 timeout time.Duration
138 transport http.RoundTripper
139 }
140
141 var _ ClientCreator = &clientCreator{}
142
143 type ClientOption func(c *clientCreator)
144
145
146
147 type ClientMiddleware func(http.RoundTripper) http.RoundTripper
148
149
150 func WithClientUserAgent(agent string) ClientOption {
151 return func(c *clientCreator) {
152 c.userAgent = agent
153 }
154 }
155
156
157
158
159
160
161 func WithClientCaching(alwaysValidate bool, cache func() httpcache.Cache) ClientOption {
162 return func(c *clientCreator) {
163 c.cacheFunc = cache
164 c.alwaysValidate = alwaysValidate
165 }
166 }
167
168
169 func WithClientTimeout(timeout time.Duration) ClientOption {
170 return func(c *clientCreator) {
171 c.timeout = timeout
172 }
173 }
174
175
176 func WithClientMiddleware(middleware ...ClientMiddleware) ClientOption {
177 return func(c *clientCreator) {
178 c.middleware = middleware
179 }
180 }
181
182
183
184
185 func WithTransport(transport http.RoundTripper) ClientOption {
186 return func(c *clientCreator) {
187 c.transport = transport
188 }
189 }
190
191 func (c *clientCreator) NewAppClient() (*github.Client, error) {
192 base := c.newHTTPClient()
193 installation, transportError := newAppInstallation(c.integrationID, c.privKeyBytes, c.v3BaseURL)
194
195 middleware := []ClientMiddleware{installation}
196 if c.cacheFunc != nil {
197 middleware = append(middleware, cache(c.cacheFunc), cacheControl(c.alwaysValidate))
198 }
199
200 client, err := c.newClient(base, middleware, "application", 0)
201 if err != nil {
202 return nil, err
203 }
204 if *transportError != nil {
205 return nil, *transportError
206 }
207 return client, nil
208 }
209
210 func (c *clientCreator) NewAppV4Client() (*githubv4.Client, error) {
211 base := c.newHTTPClient()
212 installation, transportError := newAppInstallation(c.integrationID, c.privKeyBytes, c.v3BaseURL)
213
214
215
216 middleware := []ClientMiddleware{installation}
217
218 client, err := c.newV4Client(base, middleware, "application")
219 if err != nil {
220 return nil, err
221 }
222 if *transportError != nil {
223 return nil, *transportError
224 }
225 return client, nil
226 }
227
228 func (c *clientCreator) NewInstallationClient(installationID int64) (*github.Client, error) {
229 base := c.newHTTPClient()
230 installation, transportError := newInstallation(c.integrationID, installationID, c.privKeyBytes, c.v3BaseURL)
231
232 middleware := []ClientMiddleware{installation}
233 if c.cacheFunc != nil {
234 middleware = append(middleware, cache(c.cacheFunc), cacheControl(c.alwaysValidate))
235 }
236
237 client, err := c.newClient(base, middleware, fmt.Sprintf("installation: %d", installationID), installationID)
238 if err != nil {
239 return nil, err
240 }
241 if *transportError != nil {
242 return nil, *transportError
243 }
244 return client, nil
245 }
246
247 func (c *clientCreator) NewInstallationV4Client(installationID int64) (*githubv4.Client, error) {
248 base := c.newHTTPClient()
249 installation, transportError := newInstallation(c.integrationID, installationID, c.privKeyBytes, c.v3BaseURL)
250
251
252
253 middleware := []ClientMiddleware{installation}
254
255 client, err := c.newV4Client(base, middleware, fmt.Sprintf("installation: %d", installationID))
256 if err != nil {
257 return nil, err
258 }
259 if *transportError != nil {
260 return nil, *transportError
261 }
262 return client, nil
263 }
264
265 func (c *clientCreator) NewTokenClient(token string) (*github.Client, error) {
266 ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
267 return c.NewTokenSourceClient(ts)
268 }
269
270 func (c *clientCreator) NewTokenSourceClient(ts oauth2.TokenSource) (*github.Client, error) {
271 tc := oauth2.NewClient(context.Background(), ts)
272
273 middleware := []ClientMiddleware{}
274 if c.cacheFunc != nil {
275 middleware = append(middleware, cache(c.cacheFunc), cacheControl(c.alwaysValidate))
276 }
277
278 return c.newClient(tc, middleware, "oauth token", 0)
279 }
280
281 func (c *clientCreator) NewTokenV4Client(token string) (*githubv4.Client, error) {
282 ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
283 return c.NewTokenSourceV4Client(ts)
284 }
285
286 func (c *clientCreator) NewTokenSourceV4Client(ts oauth2.TokenSource) (*githubv4.Client, error) {
287 tc := oauth2.NewClient(context.Background(), ts)
288
289
290 return c.newV4Client(tc, nil, "oauth token")
291 }
292
293 func (c *clientCreator) newHTTPClient() *http.Client {
294 transport := c.transport
295 if transport == nil {
296
297
298 transport = http.DefaultTransport
299 }
300
301 return &http.Client{
302 Transport: transport,
303 Timeout: c.timeout,
304 }
305 }
306
307 func (c *clientCreator) newClient(base *http.Client, middleware []ClientMiddleware, details string, installID int64) (*github.Client, error) {
308 applyMiddleware(base, [][]ClientMiddleware{
309 {setInstallationID(installID)},
310 c.middleware,
311 middleware,
312 })
313
314 baseURL, err := url.Parse(c.v3BaseURL)
315 if err != nil {
316 return nil, errors.Wrapf(err, "failed to parse base URL: %q", c.v3BaseURL)
317 }
318
319 client := github.NewClient(base)
320 client.BaseURL = baseURL
321 client.UserAgent = makeUserAgent(c.userAgent, details)
322
323 return client, nil
324 }
325
326 func (c *clientCreator) newV4Client(base *http.Client, middleware []ClientMiddleware, details string) (*githubv4.Client, error) {
327 applyMiddleware(base, [][]ClientMiddleware{
328 {setUserAgentHeader(makeUserAgent(c.userAgent, details))},
329 c.middleware,
330 middleware,
331 })
332
333 v4BaseURL, err := url.Parse(c.v4BaseURL)
334 if err != nil {
335 return nil, errors.Wrapf(err, "failed to parse base URL: %q", c.v4BaseURL)
336 }
337
338 client := githubv4.NewEnterpriseClient(v4BaseURL.String(), base)
339 return client, nil
340 }
341
342
343
344
345 func applyMiddleware(base *http.Client, middleware [][]ClientMiddleware) {
346 for i := len(middleware) - 1; i >= 0; i-- {
347 for j := len(middleware[i]) - 1; j >= 0; j-- {
348 base.Transport = middleware[i][j](base.Transport)
349 }
350 }
351 }
352
353 func newAppInstallation(integrationID int64, privKeyBytes []byte, v3BaseURL string) (ClientMiddleware, *error) {
354 var transportError error
355 installation := func(next http.RoundTripper) http.RoundTripper {
356 itr, err := ghinstallation.NewAppsTransport(next, integrationID, privKeyBytes)
357 if err != nil {
358 transportError = err
359 return next
360 }
361
362 itr.BaseURL = strings.TrimSuffix(v3BaseURL, "/")
363 return itr
364 }
365 return installation, &transportError
366 }
367
368 func newInstallation(integrationID, installationID int64, privKeyBytes []byte, v3BaseURL string) (ClientMiddleware, *error) {
369 var transportError error
370 installation := func(next http.RoundTripper) http.RoundTripper {
371 itr, err := ghinstallation.New(next, integrationID, installationID, privKeyBytes)
372 if err != nil {
373 transportError = err
374 return next
375 }
376
377 itr.BaseURL = strings.TrimSuffix(v3BaseURL, "/")
378 return itr
379 }
380 return installation, &transportError
381 }
382
383 func cache(cacheFunc func() httpcache.Cache) ClientMiddleware {
384 return func(next http.RoundTripper) http.RoundTripper {
385 return &httpcache.Transport{
386 Transport: next,
387 Cache: cacheFunc(),
388 MarkCachedResponses: true,
389 }
390 }
391 }
392
393 func cacheControl(alwaysValidate bool) ClientMiddleware {
394 return func(next http.RoundTripper) http.RoundTripper {
395 if !alwaysValidate {
396 return next
397 }
398
399
400
401 return roundTripperFunc(func(r *http.Request) (*http.Response, error) {
402 resp, err := next.RoundTrip(r)
403 if resp != nil {
404 cacheControl := resp.Header.Get("Cache-Control")
405 if cacheControl != "" {
406 newCacheControl := maxAgeRegex.ReplaceAllString(cacheControl, "max-age=0")
407 resp.Header.Set("Cache-Control", newCacheControl)
408 }
409 }
410 return resp, err
411 })
412 }
413 }
414
// makeUserAgent formats a user agent string as "<base> (<details>)". An
// empty base is replaced with "github-base-app/undefined".
func makeUserAgent(base, details string) string {
	const fallback = "github-base-app/undefined"

	agent := base
	if agent == "" {
		agent = fallback
	}
	return fmt.Sprintf("%s (%s)", agent, details)
}
421
422 func setInstallationID(installationID int64) ClientMiddleware {
423 return func(next http.RoundTripper) http.RoundTripper {
424 return roundTripperFunc(func(r *http.Request) (*http.Response, error) {
425 r = r.WithContext(context.WithValue(r.Context(), installationKey, installationID))
426 return next.RoundTrip(r)
427 })
428 }
429 }
430
431 func setUserAgentHeader(agent string) ClientMiddleware {
432 return func(next http.RoundTripper) http.RoundTripper {
433 return roundTripperFunc(func(r *http.Request) (*http.Response, error) {
434 r.Header.Set("User-Agent", agent)
435 return next.RoundTrip(r)
436 })
437 }
438 }
439
View as plain text