1
2
3
4
5
6
7
8
9
10
11
12
13
14 package config
15
16 import (
17 "bytes"
18 "context"
19 "crypto/sha256"
20 "crypto/tls"
21 "crypto/x509"
22 "encoding/json"
23 "fmt"
24 "net"
25 "net/http"
26 "net/url"
27 "os"
28 "path/filepath"
29 "strings"
30 "sync"
31 "time"
32
33 conntrack "github.com/mwitkow/go-conntrack"
34 "golang.org/x/net/http/httpproxy"
35 "golang.org/x/net/http2"
36 "golang.org/x/oauth2"
37 "golang.org/x/oauth2/clientcredentials"
38 "gopkg.in/yaml.v2"
39 )
40
41 var (
42
43 DefaultHTTPClientConfig = HTTPClientConfig{
44 FollowRedirects: true,
45 EnableHTTP2: true,
46 }
47
48
49 defaultHTTPClientOptions = httpClientOptions{
50 keepAlivesEnabled: true,
51 http2Enabled: true,
52
53
54 idleConnTimeout: 5 * time.Minute,
55 }
56 )
57
// closeIdler is satisfied by round trippers (for example *http.Transport)
// that can release their idle keep-alive connections on request.
type closeIdler interface {
	CloseIdleConnections()
}
61
// TLSVersion is a TLS protocol version stored as the uint16 constant used by
// crypto/tls. In YAML and JSON it is written as one of the keys of
// TLSVersions (for example "TLS12").
type TLSVersion uint16

// TLSVersions maps the configuration names of the supported TLS versions to
// their crypto/tls values.
var TLSVersions = map[string]TLSVersion{
	"TLS13": (TLSVersion)(tls.VersionTLS13),
	"TLS12": (TLSVersion)(tls.VersionTLS12),
	"TLS11": (TLSVersion)(tls.VersionTLS11),
	"TLS10": (TLSVersion)(tls.VersionTLS10),
}

// lookupTLSVersion returns the TLSVersion registered under name in
// TLSVersions, or an "unknown TLS version" error.
func lookupTLSVersion(name string) (TLSVersion, error) {
	if v, ok := TLSVersions[name]; ok {
		return v, nil
	}
	return 0, fmt.Errorf("unknown TLS version: %s", name)
}

// tlsVersionName returns the TLSVersions key whose value is tv, if any.
func tlsVersionName(tv TLSVersion) (string, bool) {
	for name, v := range TLSVersions {
		if v == tv {
			return name, true
		}
	}
	return "", false
}

// UnmarshalYAML implements yaml.Unmarshaler. The value must be a string key
// of TLSVersions; on error tv is left unchanged.
func (tv *TLSVersion) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var s string
	if err := unmarshal(&s); err != nil {
		return err
	}
	v, err := lookupTLSVersion(s)
	if err != nil {
		return err
	}
	*tv = v
	return nil
}

// MarshalYAML implements yaml.Marshaler, returning the TLSVersions key for
// tv. Versions missing from TLSVersions are an error.
func (tv TLSVersion) MarshalYAML() (interface{}, error) {
	if name, ok := tlsVersionName(tv); ok {
		return name, nil
	}
	return nil, fmt.Errorf("unknown TLS version: %d", tv)
}

// UnmarshalJSON implements json.Unmarshaler. The value must be a JSON string
// that is a key of TLSVersions; on error tv is left unchanged.
func (tv *TLSVersion) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	v, err := lookupTLSVersion(s)
	if err != nil {
		return err
	}
	*tv = v
	return nil
}

// MarshalJSON implements json.Marshaler, encoding tv as its TLSVersions key.
// Versions missing from TLSVersions are an error.
func (tv TLSVersion) MarshalJSON() ([]byte, error) {
	if name, ok := tlsVersionName(tv); ok {
		return json.Marshal(name)
	}
	return nil, fmt.Errorf("unknown TLS version: %d", tv)
}

// String implements fmt.Stringer. It returns "" for a nil or zero version,
// the TLSVersions key for known versions, and the decimal value otherwise.
func (tv *TLSVersion) String() string {
	if tv == nil || *tv == 0 {
		return ""
	}
	if name, ok := tlsVersionName(*tv); ok {
		return name
	}
	// Format the dereferenced value: applying %d to tv itself would print
	// the pointer's address rather than the version number.
	return fmt.Sprintf("%d", uint16(*tv))
}
128
129
130 type BasicAuth struct {
131 Username string `yaml:"username" json:"username"`
132 UsernameFile string `yaml:"username_file,omitempty" json:"username_file,omitempty"`
133 Password Secret `yaml:"password,omitempty" json:"password,omitempty"`
134 PasswordFile string `yaml:"password_file,omitempty" json:"password_file,omitempty"`
135 }
136
137
138 func (a *BasicAuth) SetDirectory(dir string) {
139 if a == nil {
140 return
141 }
142 a.PasswordFile = JoinDir(dir, a.PasswordFile)
143 a.UsernameFile = JoinDir(dir, a.UsernameFile)
144 }
145
146
147 type Authorization struct {
148 Type string `yaml:"type,omitempty" json:"type,omitempty"`
149 Credentials Secret `yaml:"credentials,omitempty" json:"credentials,omitempty"`
150 CredentialsFile string `yaml:"credentials_file,omitempty" json:"credentials_file,omitempty"`
151 }
152
153
154 func (a *Authorization) SetDirectory(dir string) {
155 if a == nil {
156 return
157 }
158 a.CredentialsFile = JoinDir(dir, a.CredentialsFile)
159 }
160
161
// URL wraps *url.URL so that it can be read from and written to YAML and JSON
// configuration as a plain string.
type URL struct {
	*url.URL
}

// UnmarshalYAML implements yaml.Unmarshaler by parsing a string with
// url.Parse.
func (u *URL) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var raw string
	if err := unmarshal(&raw); err != nil {
		return err
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return err
	}
	u.URL = parsed
	return nil
}

// MarshalYAML implements yaml.Marshaler. Any password in the URL is redacted,
// and a nil URL marshals to nil.
func (u URL) MarshalYAML() (interface{}, error) {
	if u.URL == nil {
		return nil, nil
	}
	return u.Redacted(), nil
}

// Redacted returns the URL as a string with any password replaced by "xxxxx".
// It returns "" for a nil URL and never modifies the wrapped URL.
func (u URL) Redacted() string {
	if u.URL == nil {
		return ""
	}
	// Shallow copy: only the copy's User field is replaced below.
	redacted := *u.URL
	if _, hasPassword := redacted.User.Password(); hasPassword {
		redacted.User = url.UserPassword(redacted.User.Username(), "xxxxx")
	}
	return redacted.String()
}

// UnmarshalJSON implements json.Unmarshaler by parsing a JSON string with
// url.Parse.
func (u *URL) UnmarshalJSON(data []byte) error {
	var raw string
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return err
	}
	u.URL = parsed
	return nil
}

// MarshalJSON implements json.Marshaler. Unlike MarshalYAML, the URL is
// written unredacted. A nil URL marshals to null.
func (u URL) MarshalJSON() ([]byte, error) {
	if u.URL == nil {
		return []byte("null"), nil
	}
	return json.Marshal(u.URL.String())
}
224
225
226 type OAuth2 struct {
227 ClientID string `yaml:"client_id" json:"client_id"`
228 ClientSecret Secret `yaml:"client_secret" json:"client_secret"`
229 ClientSecretFile string `yaml:"client_secret_file" json:"client_secret_file"`
230 Scopes []string `yaml:"scopes,omitempty" json:"scopes,omitempty"`
231 TokenURL string `yaml:"token_url" json:"token_url"`
232 EndpointParams map[string]string `yaml:"endpoint_params,omitempty" json:"endpoint_params,omitempty"`
233 TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
234 ProxyConfig `yaml:",inline"`
235 }
236
237
238 func (o *OAuth2) UnmarshalYAML(unmarshal func(interface{}) error) error {
239 type plain OAuth2
240 if err := unmarshal((*plain)(o)); err != nil {
241 return err
242 }
243 return o.ProxyConfig.Validate()
244 }
245
246
247 func (o *OAuth2) UnmarshalJSON(data []byte) error {
248 type plain OAuth2
249 if err := json.Unmarshal(data, (*plain)(o)); err != nil {
250 return err
251 }
252 return o.ProxyConfig.Validate()
253 }
254
255
256 func (a *OAuth2) SetDirectory(dir string) {
257 if a == nil {
258 return
259 }
260 a.ClientSecretFile = JoinDir(dir, a.ClientSecretFile)
261 a.TLSConfig.SetDirectory(dir)
262 }
263
264
265 func LoadHTTPConfig(s string) (*HTTPClientConfig, error) {
266 cfg := &HTTPClientConfig{}
267 err := yaml.UnmarshalStrict([]byte(s), cfg)
268 if err != nil {
269 return nil, err
270 }
271 return cfg, nil
272 }
273
274
275 func LoadHTTPConfigFile(filename string) (*HTTPClientConfig, []byte, error) {
276 content, err := os.ReadFile(filename)
277 if err != nil {
278 return nil, nil, err
279 }
280 cfg, err := LoadHTTPConfig(string(content))
281 if err != nil {
282 return nil, nil, err
283 }
284 cfg.SetDirectory(filepath.Dir(filepath.Dir(filename)))
285 return cfg, content, nil
286 }
287
288
289 type HTTPClientConfig struct {
290
291 BasicAuth *BasicAuth `yaml:"basic_auth,omitempty" json:"basic_auth,omitempty"`
292
293 Authorization *Authorization `yaml:"authorization,omitempty" json:"authorization,omitempty"`
294
295 OAuth2 *OAuth2 `yaml:"oauth2,omitempty" json:"oauth2,omitempty"`
296
297
298 BearerToken Secret `yaml:"bearer_token,omitempty" json:"bearer_token,omitempty"`
299
300
301 BearerTokenFile string `yaml:"bearer_token_file,omitempty" json:"bearer_token_file,omitempty"`
302
303 TLSConfig TLSConfig `yaml:"tls_config,omitempty" json:"tls_config,omitempty"`
304
305
306
307 FollowRedirects bool `yaml:"follow_redirects" json:"follow_redirects"`
308
309
310
311 EnableHTTP2 bool `yaml:"enable_http2" json:"enable_http2"`
312
313 ProxyConfig `yaml:",inline"`
314 }
315
316
317 func (c *HTTPClientConfig) SetDirectory(dir string) {
318 if c == nil {
319 return
320 }
321 c.TLSConfig.SetDirectory(dir)
322 c.BasicAuth.SetDirectory(dir)
323 c.Authorization.SetDirectory(dir)
324 c.OAuth2.SetDirectory(dir)
325 c.BearerTokenFile = JoinDir(dir, c.BearerTokenFile)
326 }
327
328
329
330
331 func (c *HTTPClientConfig) Validate() error {
332
333 if len(c.BearerToken) > 0 && len(c.BearerTokenFile) > 0 {
334 return fmt.Errorf("at most one of bearer_token & bearer_token_file must be configured")
335 }
336 if (c.BasicAuth != nil || c.OAuth2 != nil) && (len(c.BearerToken) > 0 || len(c.BearerTokenFile) > 0) {
337 return fmt.Errorf("at most one of basic_auth, oauth2, bearer_token & bearer_token_file must be configured")
338 }
339 if c.BasicAuth != nil && (string(c.BasicAuth.Username) != "" && c.BasicAuth.UsernameFile != "") {
340 return fmt.Errorf("at most one of basic_auth username & username_file must be configured")
341 }
342 if c.BasicAuth != nil && (string(c.BasicAuth.Password) != "" && c.BasicAuth.PasswordFile != "") {
343 return fmt.Errorf("at most one of basic_auth password & password_file must be configured")
344 }
345 if c.Authorization != nil {
346 if len(c.BearerToken) > 0 || len(c.BearerTokenFile) > 0 {
347 return fmt.Errorf("authorization is not compatible with bearer_token & bearer_token_file")
348 }
349 if string(c.Authorization.Credentials) != "" && c.Authorization.CredentialsFile != "" {
350 return fmt.Errorf("at most one of authorization credentials & credentials_file must be configured")
351 }
352 c.Authorization.Type = strings.TrimSpace(c.Authorization.Type)
353 if len(c.Authorization.Type) == 0 {
354 c.Authorization.Type = "Bearer"
355 }
356 if strings.ToLower(c.Authorization.Type) == "basic" {
357 return fmt.Errorf(`authorization type cannot be set to "basic", use "basic_auth" instead`)
358 }
359 if c.BasicAuth != nil || c.OAuth2 != nil {
360 return fmt.Errorf("at most one of basic_auth, oauth2 & authorization must be configured")
361 }
362 } else {
363 if len(c.BearerToken) > 0 {
364 c.Authorization = &Authorization{Credentials: c.BearerToken}
365 c.Authorization.Type = "Bearer"
366 c.BearerToken = ""
367 }
368 if len(c.BearerTokenFile) > 0 {
369 c.Authorization = &Authorization{CredentialsFile: c.BearerTokenFile}
370 c.Authorization.Type = "Bearer"
371 c.BearerTokenFile = ""
372 }
373 }
374 if c.OAuth2 != nil {
375 if c.BasicAuth != nil {
376 return fmt.Errorf("at most one of basic_auth, oauth2 & authorization must be configured")
377 }
378 if len(c.OAuth2.ClientID) == 0 {
379 return fmt.Errorf("oauth2 client_id must be configured")
380 }
381 if len(c.OAuth2.TokenURL) == 0 {
382 return fmt.Errorf("oauth2 token_url must be configured")
383 }
384 if len(c.OAuth2.ClientSecret) > 0 && len(c.OAuth2.ClientSecretFile) > 0 {
385 return fmt.Errorf("at most one of oauth2 client_secret & client_secret_file must be configured")
386 }
387 }
388 if err := c.ProxyConfig.Validate(); err != nil {
389 return err
390 }
391 return nil
392 }
393
394
395 func (c *HTTPClientConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
396 type plain HTTPClientConfig
397 *c = DefaultHTTPClientConfig
398 if err := unmarshal((*plain)(c)); err != nil {
399 return err
400 }
401 return c.Validate()
402 }
403
404
405 func (c *HTTPClientConfig) UnmarshalJSON(data []byte) error {
406 type plain HTTPClientConfig
407 *c = DefaultHTTPClientConfig
408 if err := json.Unmarshal(data, (*plain)(c)); err != nil {
409 return err
410 }
411 return c.Validate()
412 }
413
414
415 func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error {
416 type plain BasicAuth
417 return unmarshal((*plain)(a))
418 }
419
420
421
// DialContextFunc dials a network connection. Its signature matches
// net.Dialer.DialContext.
type DialContextFunc func(context.Context, string, string) (net.Conn, error)

// httpClientOptions holds the settings that HTTPClientOption values modify.
type httpClientOptions struct {
	dialContextFunc   DialContextFunc
	keepAlivesEnabled bool
	http2Enabled      bool
	idleConnTimeout   time.Duration
	userAgent         string
	host              string
}

// HTTPClientOption customizes the round tripper built by
// NewRoundTripperFromConfig and NewClientFromConfig.
type HTTPClientOption func(options *httpClientOptions)

// WithDialContextFunc makes the client dial its connections with fn.
func WithDialContextFunc(fn DialContextFunc) HTTPClientOption {
	return func(o *httpClientOptions) { o.dialContextFunc = fn }
}

// WithKeepAlivesDisabled disables HTTP keep-alive connections.
func WithKeepAlivesDisabled() HTTPClientOption {
	return func(o *httpClientOptions) { o.keepAlivesEnabled = false }
}

// WithHTTP2Disabled disables HTTP/2, regardless of the configuration.
func WithHTTP2Disabled() HTTPClientOption {
	return func(o *httpClientOptions) { o.http2Enabled = false }
}

// WithIdleConnTimeout sets how long an idle connection is kept before it is
// closed.
func WithIdleConnTimeout(timeout time.Duration) HTTPClientOption {
	return func(o *httpClientOptions) { o.idleConnTimeout = timeout }
}

// WithUserAgent sets the User-Agent header sent with every request.
func WithUserAgent(ua string) HTTPClientOption {
	return func(o *httpClientOptions) { o.userAgent = ua }
}

// WithHost sets the Host used for every request.
func WithHost(host string) HTTPClientOption {
	return func(o *httpClientOptions) { o.host = host }
}
477
478
// newClient returns an *http.Client that sends requests through rt. All
// other settings are zero values: no timeout and default redirect handling.
func newClient(rt http.RoundTripper) *http.Client {
	client := new(http.Client)
	client.Transport = rt
	return client
}
482
483
484
485
486 func NewClientFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (*http.Client, error) {
487 rt, err := NewRoundTripperFromConfig(cfg, name, optFuncs...)
488 if err != nil {
489 return nil, err
490 }
491 client := newClient(rt)
492 if !cfg.FollowRedirects {
493 client.CheckRedirect = func(*http.Request, []*http.Request) error {
494 return http.ErrUseLastResponse
495 }
496 }
497 return client, nil
498 }
499
500
501
502
503 func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) {
504 opts := defaultHTTPClientOptions
505 for _, f := range optFuncs {
506 f(&opts)
507 }
508
509 var dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
510
511 if opts.dialContextFunc != nil {
512 dialContext = conntrack.NewDialContextFunc(
513 conntrack.DialWithDialContextFunc((func(context.Context, string, string) (net.Conn, error))(opts.dialContextFunc)),
514 conntrack.DialWithTracing(),
515 conntrack.DialWithName(name))
516 } else {
517 dialContext = conntrack.NewDialContextFunc(
518 conntrack.DialWithTracing(),
519 conntrack.DialWithName(name))
520 }
521
522 newRT := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
523
524
525 var rt http.RoundTripper = &http.Transport{
526 Proxy: cfg.ProxyConfig.Proxy(),
527 ProxyConnectHeader: cfg.ProxyConfig.GetProxyConnectHeader(),
528 MaxIdleConns: 20000,
529 MaxIdleConnsPerHost: 1000,
530 DisableKeepAlives: !opts.keepAlivesEnabled,
531 TLSClientConfig: tlsConfig,
532 DisableCompression: true,
533 IdleConnTimeout: opts.idleConnTimeout,
534 TLSHandshakeTimeout: 10 * time.Second,
535 ExpectContinueTimeout: 1 * time.Second,
536 DialContext: dialContext,
537 }
538 if opts.http2Enabled && cfg.EnableHTTP2 {
539
540
541
542
543
544
545 http2t, err := http2.ConfigureTransports(rt.(*http.Transport))
546 if err != nil {
547 return nil, err
548 }
549 http2t.ReadIdleTimeout = time.Minute
550 }
551
552
553
554 if cfg.Authorization != nil && len(cfg.Authorization.CredentialsFile) > 0 {
555 rt = NewAuthorizationCredentialsFileRoundTripper(cfg.Authorization.Type, cfg.Authorization.CredentialsFile, rt)
556 } else if cfg.Authorization != nil {
557 rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, cfg.Authorization.Credentials, rt)
558 }
559
560
561 if len(cfg.BearerToken) > 0 {
562 rt = NewAuthorizationCredentialsRoundTripper("Bearer", cfg.BearerToken, rt)
563 } else if len(cfg.BearerTokenFile) > 0 {
564 rt = NewAuthorizationCredentialsFileRoundTripper("Bearer", cfg.BearerTokenFile, rt)
565 }
566
567 if cfg.BasicAuth != nil {
568 rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.UsernameFile, cfg.BasicAuth.PasswordFile, rt)
569 }
570
571 if cfg.OAuth2 != nil {
572 rt = NewOAuth2RoundTripper(cfg.OAuth2, rt, &opts)
573 }
574
575 if opts.userAgent != "" {
576 rt = NewUserAgentRoundTripper(opts.userAgent, rt)
577 }
578
579 if opts.host != "" {
580 rt = NewHostRoundTripper(opts.host, rt)
581 }
582
583
584 return rt, nil
585 }
586
587 tlsConfig, err := NewTLSConfig(&cfg.TLSConfig)
588 if err != nil {
589 return nil, err
590 }
591
592 if len(cfg.TLSConfig.CAFile) == 0 {
593
594 return newRT(tlsConfig)
595 }
596 return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.roundTripperSettings(), newRT)
597 }
598
599 type authorizationCredentialsRoundTripper struct {
600 authType string
601 authCredentials Secret
602 rt http.RoundTripper
603 }
604
605
606
607 func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials Secret, rt http.RoundTripper) http.RoundTripper {
608 return &authorizationCredentialsRoundTripper{authType, authCredentials, rt}
609 }
610
611 func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
612 if len(req.Header.Get("Authorization")) == 0 {
613 req = cloneRequest(req)
614 req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, string(rt.authCredentials)))
615 }
616 return rt.rt.RoundTrip(req)
617 }
618
619 func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() {
620 if ci, ok := rt.rt.(closeIdler); ok {
621 ci.CloseIdleConnections()
622 }
623 }
624
625 type authorizationCredentialsFileRoundTripper struct {
626 authType string
627 authCredentialsFile string
628 rt http.RoundTripper
629 }
630
631
632
633
634 func NewAuthorizationCredentialsFileRoundTripper(authType, authCredentialsFile string, rt http.RoundTripper) http.RoundTripper {
635 return &authorizationCredentialsFileRoundTripper{authType, authCredentialsFile, rt}
636 }
637
638 func (rt *authorizationCredentialsFileRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
639 if len(req.Header.Get("Authorization")) == 0 {
640 b, err := os.ReadFile(rt.authCredentialsFile)
641 if err != nil {
642 return nil, fmt.Errorf("unable to read authorization credentials file %s: %w", rt.authCredentialsFile, err)
643 }
644 authCredentials := strings.TrimSpace(string(b))
645
646 req = cloneRequest(req)
647 req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials))
648 }
649
650 return rt.rt.RoundTrip(req)
651 }
652
653 func (rt *authorizationCredentialsFileRoundTripper) CloseIdleConnections() {
654 if ci, ok := rt.rt.(closeIdler); ok {
655 ci.CloseIdleConnections()
656 }
657 }
658
659 type basicAuthRoundTripper struct {
660 username string
661 password Secret
662 usernameFile string
663 passwordFile string
664 rt http.RoundTripper
665 }
666
667
668
669 func NewBasicAuthRoundTripper(username string, password Secret, usernameFile, passwordFile string, rt http.RoundTripper) http.RoundTripper {
670 return &basicAuthRoundTripper{username, password, usernameFile, passwordFile, rt}
671 }
672
673 func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
674 var username string
675 var password string
676 if len(req.Header.Get("Authorization")) != 0 {
677 return rt.rt.RoundTrip(req)
678 }
679 if rt.usernameFile != "" {
680 usernameBytes, err := os.ReadFile(rt.usernameFile)
681 if err != nil {
682 return nil, fmt.Errorf("unable to read basic auth username file %s: %w", rt.usernameFile, err)
683 }
684 username = strings.TrimSpace(string(usernameBytes))
685 } else {
686 username = rt.username
687 }
688 if rt.passwordFile != "" {
689 passwordBytes, err := os.ReadFile(rt.passwordFile)
690 if err != nil {
691 return nil, fmt.Errorf("unable to read basic auth password file %s: %w", rt.passwordFile, err)
692 }
693 password = strings.TrimSpace(string(passwordBytes))
694 } else {
695 password = string(rt.password)
696 }
697 req = cloneRequest(req)
698 req.SetBasicAuth(username, password)
699 return rt.rt.RoundTrip(req)
700 }
701
702 func (rt *basicAuthRoundTripper) CloseIdleConnections() {
703 if ci, ok := rt.rt.(closeIdler); ok {
704 ci.CloseIdleConnections()
705 }
706 }
707
708 type oauth2RoundTripper struct {
709 config *OAuth2
710 rt http.RoundTripper
711 next http.RoundTripper
712 secret string
713 mtx sync.RWMutex
714 opts *httpClientOptions
715 client *http.Client
716 }
717
718 func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
719 return &oauth2RoundTripper{
720 config: config,
721 next: next,
722 opts: opts,
723 }
724 }
725
726 func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
727 var (
728 secret string
729 changed bool
730 )
731
732 if rt.config.ClientSecretFile != "" {
733 data, err := os.ReadFile(rt.config.ClientSecretFile)
734 if err != nil {
735 return nil, fmt.Errorf("unable to read oauth2 client secret file %s: %w", rt.config.ClientSecretFile, err)
736 }
737 secret = strings.TrimSpace(string(data))
738 rt.mtx.RLock()
739 changed = secret != rt.secret
740 rt.mtx.RUnlock()
741 } else {
742
743 secret = string(rt.config.ClientSecret)
744 }
745
746 if changed || rt.rt == nil {
747 config := &clientcredentials.Config{
748 ClientID: rt.config.ClientID,
749 ClientSecret: secret,
750 Scopes: rt.config.Scopes,
751 TokenURL: rt.config.TokenURL,
752 EndpointParams: mapToValues(rt.config.EndpointParams),
753 }
754
755 tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig)
756 if err != nil {
757 return nil, err
758 }
759
760 tlsTransport := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
761 return &http.Transport{
762 TLSClientConfig: tlsConfig,
763 Proxy: rt.config.ProxyConfig.Proxy(),
764 ProxyConnectHeader: rt.config.ProxyConfig.GetProxyConnectHeader(),
765 DisableKeepAlives: !rt.opts.keepAlivesEnabled,
766 MaxIdleConns: 20,
767 MaxIdleConnsPerHost: 1,
768 IdleConnTimeout: 10 * time.Second,
769 TLSHandshakeTimeout: 10 * time.Second,
770 ExpectContinueTimeout: 1 * time.Second,
771 }, nil
772 }
773
774 var t http.RoundTripper
775 if len(rt.config.TLSConfig.CAFile) == 0 {
776 t, _ = tlsTransport(tlsConfig)
777 } else {
778 t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.roundTripperSettings(), tlsTransport)
779 if err != nil {
780 return nil, err
781 }
782 }
783
784 if ua := req.UserAgent(); ua != "" {
785 t = NewUserAgentRoundTripper(ua, t)
786 }
787
788 client := &http.Client{Transport: t}
789 ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client)
790 tokenSource := config.TokenSource(ctx)
791
792 rt.mtx.Lock()
793 rt.secret = secret
794 rt.rt = &oauth2.Transport{
795 Base: rt.next,
796 Source: tokenSource,
797 }
798 if rt.client != nil {
799 rt.client.CloseIdleConnections()
800 }
801 rt.client = client
802 rt.mtx.Unlock()
803 }
804
805 rt.mtx.RLock()
806 currentRT := rt.rt
807 rt.mtx.RUnlock()
808 return currentRT.RoundTrip(req)
809 }
810
811 func (rt *oauth2RoundTripper) CloseIdleConnections() {
812 if rt.client != nil {
813 rt.client.CloseIdleConnections()
814 }
815 if ci, ok := rt.next.(closeIdler); ok {
816 ci.CloseIdleConnections()
817 }
818 }
819
// mapToValues converts m into url.Values with exactly one value per key. The
// result is never nil, even for a nil or empty m.
func mapToValues(m map[string]string) url.Values {
	values := make(url.Values, len(m))
	for key, value := range m {
		values[key] = []string{value}
	}
	return values
}
828
829
830
// cloneRequest returns a shallow copy of r that has its own Header map, so
// headers can be set on the copy without mutating r. The header value slices
// and all other fields (URL, Body, context, ...) are shared with r. The
// copy's Header is never nil.
func cloneRequest(r *http.Request) *http.Request {
	clone := *r
	clone.Header = make(http.Header, len(r.Header))
	for name, values := range r.Header {
		clone.Header[name] = values
	}
	return &clone
}
842
843
844 func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
845 if err := cfg.Validate(); err != nil {
846 return nil, err
847 }
848
849 tlsConfig := &tls.Config{
850 InsecureSkipVerify: cfg.InsecureSkipVerify,
851 MinVersion: uint16(cfg.MinVersion),
852 MaxVersion: uint16(cfg.MaxVersion),
853 }
854
855 if cfg.MaxVersion != 0 && cfg.MinVersion != 0 {
856 if cfg.MaxVersion < cfg.MinVersion {
857 return nil, fmt.Errorf("tls_config.max_version must be greater than or equal to tls_config.min_version if both are specified")
858 }
859 }
860
861
862
863 if len(cfg.CA) > 0 {
864 if !updateRootCA(tlsConfig, []byte(cfg.CA)) {
865 return nil, fmt.Errorf("unable to use inline CA cert")
866 }
867 } else if len(cfg.CAFile) > 0 {
868 b, err := readCAFile(cfg.CAFile)
869 if err != nil {
870 return nil, err
871 }
872 if !updateRootCA(tlsConfig, b) {
873 return nil, fmt.Errorf("unable to use specified CA cert %s", cfg.CAFile)
874 }
875 }
876
877 if len(cfg.ServerName) > 0 {
878 tlsConfig.ServerName = cfg.ServerName
879 }
880
881
882 if cfg.usingClientCert() && cfg.usingClientKey() {
883
884 if _, err := cfg.getClientCertificate(nil); err != nil {
885 return nil, err
886 }
887 tlsConfig.GetClientCertificate = cfg.getClientCertificate
888 }
889
890 return tlsConfig, nil
891 }
892
893
894 type TLSConfig struct {
895
896 CA string `yaml:"ca,omitempty" json:"ca,omitempty"`
897
898 Cert string `yaml:"cert,omitempty" json:"cert,omitempty"`
899
900 Key Secret `yaml:"key,omitempty" json:"key,omitempty"`
901
902 CAFile string `yaml:"ca_file,omitempty" json:"ca_file,omitempty"`
903
904 CertFile string `yaml:"cert_file,omitempty" json:"cert_file,omitempty"`
905
906 KeyFile string `yaml:"key_file,omitempty" json:"key_file,omitempty"`
907
908 ServerName string `yaml:"server_name,omitempty" json:"server_name,omitempty"`
909
910 InsecureSkipVerify bool `yaml:"insecure_skip_verify" json:"insecure_skip_verify"`
911
912 MinVersion TLSVersion `yaml:"min_version,omitempty" json:"min_version,omitempty"`
913
914 MaxVersion TLSVersion `yaml:"max_version,omitempty" json:"max_version,omitempty"`
915 }
916
917
918 func (c *TLSConfig) SetDirectory(dir string) {
919 if c == nil {
920 return
921 }
922 c.CAFile = JoinDir(dir, c.CAFile)
923 c.CertFile = JoinDir(dir, c.CertFile)
924 c.KeyFile = JoinDir(dir, c.KeyFile)
925 }
926
927
928 func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
929 type plain TLSConfig
930 if err := unmarshal((*plain)(c)); err != nil {
931 return err
932 }
933 return c.Validate()
934 }
935
936
937
938
939 func (c *TLSConfig) Validate() error {
940 if len(c.CA) > 0 && len(c.CAFile) > 0 {
941 return fmt.Errorf("at most one of ca and ca_file must be configured")
942 }
943 if len(c.Cert) > 0 && len(c.CertFile) > 0 {
944 return fmt.Errorf("at most one of cert and cert_file must be configured")
945 }
946 if len(c.Key) > 0 && len(c.KeyFile) > 0 {
947 return fmt.Errorf("at most one of key and key_file must be configured")
948 }
949
950 if c.usingClientCert() && !c.usingClientKey() {
951 return fmt.Errorf("exactly one of key or key_file must be configured when a client certificate is configured")
952 } else if c.usingClientKey() && !c.usingClientCert() {
953 return fmt.Errorf("exactly one of cert or cert_file must be configured when a client key is configured")
954 }
955
956 return nil
957 }
958
959 func (c *TLSConfig) usingClientCert() bool {
960 return len(c.Cert) > 0 || len(c.CertFile) > 0
961 }
962
963 func (c *TLSConfig) usingClientKey() bool {
964 return len(c.Key) > 0 || len(c.KeyFile) > 0
965 }
966
967 func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings {
968 return TLSRoundTripperSettings{
969 CA: c.CA,
970 CAFile: c.CAFile,
971 Cert: c.Cert,
972 CertFile: c.CertFile,
973 Key: string(c.Key),
974 KeyFile: c.KeyFile,
975 }
976 }
977
978
979 func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
980 var (
981 certData, keyData []byte
982 err error
983 )
984
985 if c.CertFile != "" {
986 certData, err = os.ReadFile(c.CertFile)
987 if err != nil {
988 return nil, fmt.Errorf("unable to read specified client cert (%s): %w", c.CertFile, err)
989 }
990 } else {
991 certData = []byte(c.Cert)
992 }
993
994 if c.KeyFile != "" {
995 keyData, err = os.ReadFile(c.KeyFile)
996 if err != nil {
997 return nil, fmt.Errorf("unable to read specified client key (%s): %w", c.KeyFile, err)
998 }
999 } else {
1000 keyData = []byte(c.Key)
1001 }
1002
1003 cert, err := tls.X509KeyPair(certData, keyData)
1004 if err != nil {
1005 return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", c.CertFile, c.KeyFile, err)
1006 }
1007
1008 return &cert, nil
1009 }
1010
1011
// readCAFile returns the contents of the CA certificate file f. Read errors
// are wrapped with the path.
func readCAFile(f string) ([]byte, error) {
	data, err := os.ReadFile(f)
	if err == nil {
		return data, nil
	}
	return nil, fmt.Errorf("unable to load specified CA cert %s: %w", f, err)
}

// updateRootCA replaces cfg.RootCAs with a new pool holding the PEM
// certificates in b. It reports false, leaving cfg untouched, when no
// certificate in b could be parsed.
func updateRootCA(cfg *tls.Config, b []byte) bool {
	pool := x509.NewCertPool()
	ok := pool.AppendCertsFromPEM(b)
	if ok {
		cfg.RootCAs = pool
	}
	return ok
}
1029
1030
1031
1032 type tlsRoundTripper struct {
1033 settings TLSRoundTripperSettings
1034
1035
1036 newRT func(*tls.Config) (http.RoundTripper, error)
1037
1038 mtx sync.RWMutex
1039 rt http.RoundTripper
1040 hashCAData []byte
1041 hashCertData []byte
1042 hashKeyData []byte
1043 tlsConfig *tls.Config
1044 }
1045
1046 type TLSRoundTripperSettings struct {
1047 CA, CAFile string
1048 Cert, CertFile string
1049 Key, KeyFile string
1050 }
1051
1052 func NewTLSRoundTripper(
1053 cfg *tls.Config,
1054 settings TLSRoundTripperSettings,
1055 newRT func(*tls.Config) (http.RoundTripper, error),
1056 ) (http.RoundTripper, error) {
1057 t := &tlsRoundTripper{
1058 settings: settings,
1059 newRT: newRT,
1060 tlsConfig: cfg,
1061 }
1062
1063 rt, err := t.newRT(t.tlsConfig)
1064 if err != nil {
1065 return nil, err
1066 }
1067 t.rt = rt
1068 _, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash()
1069 if err != nil {
1070 return nil, err
1071 }
1072
1073 return t, nil
1074 }
1075
1076 func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, error) {
1077 var (
1078 caBytes, certBytes, keyBytes []byte
1079
1080 err error
1081 )
1082
1083 if t.settings.CAFile != "" {
1084 caBytes, err = os.ReadFile(t.settings.CAFile)
1085 if err != nil {
1086 return nil, nil, nil, nil, err
1087 }
1088 } else if t.settings.CA != "" {
1089 caBytes = []byte(t.settings.CA)
1090 }
1091
1092 if t.settings.CertFile != "" {
1093 certBytes, err = os.ReadFile(t.settings.CertFile)
1094 if err != nil {
1095 return nil, nil, nil, nil, err
1096 }
1097 } else if t.settings.Cert != "" {
1098 certBytes = []byte(t.settings.Cert)
1099 }
1100
1101 if t.settings.KeyFile != "" {
1102 keyBytes, err = os.ReadFile(t.settings.KeyFile)
1103 if err != nil {
1104 return nil, nil, nil, nil, err
1105 }
1106 } else if t.settings.Key != "" {
1107 keyBytes = []byte(t.settings.Key)
1108 }
1109
1110 var caHash, certHash, keyHash [32]byte
1111
1112 if len(caBytes) > 0 {
1113 caHash = sha256.Sum256(caBytes)
1114 }
1115 if len(certBytes) > 0 {
1116 certHash = sha256.Sum256(certBytes)
1117 }
1118 if len(keyBytes) > 0 {
1119 keyHash = sha256.Sum256(keyBytes)
1120 }
1121
1122 return caBytes, caHash[:], certHash[:], keyHash[:], nil
1123 }
1124
1125
1126 func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
1127 caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash()
1128 if err != nil {
1129 return nil, err
1130 }
1131
1132 t.mtx.RLock()
1133 equal := bytes.Equal(caHash[:], t.hashCAData) &&
1134 bytes.Equal(certHash[:], t.hashCertData) &&
1135 bytes.Equal(keyHash[:], t.hashKeyData)
1136 rt := t.rt
1137 t.mtx.RUnlock()
1138 if equal {
1139
1140 return rt.RoundTrip(req)
1141 }
1142
1143
1144
1145
1146 tlsConfig := t.tlsConfig.Clone()
1147 if !updateRootCA(tlsConfig, caData) {
1148 return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CAFile)
1149 }
1150 rt, err = t.newRT(tlsConfig)
1151 if err != nil {
1152 return nil, err
1153 }
1154 t.CloseIdleConnections()
1155
1156 t.mtx.Lock()
1157 t.rt = rt
1158 t.hashCAData = caHash[:]
1159 t.hashCertData = certHash[:]
1160 t.hashKeyData = keyHash[:]
1161 t.mtx.Unlock()
1162
1163 return rt.RoundTrip(req)
1164 }
1165
1166 func (t *tlsRoundTripper) CloseIdleConnections() {
1167 t.mtx.RLock()
1168 defer t.mtx.RUnlock()
1169 if ci, ok := t.rt.(closeIdler); ok {
1170 ci.CloseIdleConnections()
1171 }
1172 }
1173
1174 type userAgentRoundTripper struct {
1175 userAgent string
1176 rt http.RoundTripper
1177 }
1178
1179 type hostRoundTripper struct {
1180 host string
1181 rt http.RoundTripper
1182 }
1183
1184
1185 func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper {
1186 return &userAgentRoundTripper{userAgent, rt}
1187 }
1188
1189
1190 func NewHostRoundTripper(host string, rt http.RoundTripper) http.RoundTripper {
1191 return &hostRoundTripper{host, rt}
1192 }
1193
1194 func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
1195 req = cloneRequest(req)
1196 req.Header.Set("User-Agent", rt.userAgent)
1197 return rt.rt.RoundTrip(req)
1198 }
1199
1200 func (rt *userAgentRoundTripper) CloseIdleConnections() {
1201 if ci, ok := rt.rt.(closeIdler); ok {
1202 ci.CloseIdleConnections()
1203 }
1204 }
1205
1206 func (rt *hostRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
1207 req = cloneRequest(req)
1208 req.Host = rt.host
1209 req.Header.Set("Host", rt.host)
1210 return rt.rt.RoundTrip(req)
1211 }
1212
1213 func (rt *hostRoundTripper) CloseIdleConnections() {
1214 if ci, ok := rt.rt.(closeIdler); ok {
1215 ci.CloseIdleConnections()
1216 }
1217 }
1218
1219 func (c HTTPClientConfig) String() string {
1220 b, err := yaml.Marshal(c)
1221 if err != nil {
1222 return fmt.Sprintf("<error creating http client config string: %s>", err)
1223 }
1224 return string(b)
1225 }
1226
1227 type ProxyConfig struct {
1228
1229 ProxyURL URL `yaml:"proxy_url,omitempty" json:"proxy_url,omitempty"`
1230
1231 NoProxy string `yaml:"no_proxy,omitempty" json:"no_proxy,omitempty"`
1232
1233
1234 ProxyFromEnvironment bool `yaml:"proxy_from_environment,omitempty" json:"proxy_from_environment,omitempty"`
1235
1236
1237
1238
1239 ProxyConnectHeader Header `yaml:"proxy_connect_header,omitempty" json:"proxy_connect_header,omitempty"`
1240
1241 proxyFunc func(*http.Request) (*url.URL, error)
1242 }
1243
1244
1245 func (c *ProxyConfig) Validate() error {
1246 if len(c.ProxyConnectHeader) > 0 && (!c.ProxyFromEnvironment && (c.ProxyURL.URL == nil || c.ProxyURL.String() == "")) {
1247 return fmt.Errorf("if proxy_connect_header is configured, proxy_url or proxy_from_environment must also be configured")
1248 }
1249 if c.ProxyFromEnvironment && c.ProxyURL.URL != nil && c.ProxyURL.String() != "" {
1250 return fmt.Errorf("if proxy_from_environment is configured, proxy_url must not be configured")
1251 }
1252 if c.ProxyFromEnvironment && c.NoProxy != "" {
1253 return fmt.Errorf("if proxy_from_environment is configured, no_proxy must not be configured")
1254 }
1255 if c.ProxyURL.URL == nil && c.NoProxy != "" {
1256 return fmt.Errorf("if no_proxy is configured, proxy_url must also be configured")
1257 }
1258 return nil
1259 }
1260
1261
1262 func (c *ProxyConfig) Proxy() (fn func(*http.Request) (*url.URL, error)) {
1263 if c == nil {
1264 return nil
1265 }
1266 defer func() {
1267 fn = c.proxyFunc
1268 }()
1269 if c.proxyFunc != nil {
1270 return
1271 }
1272 if c.ProxyFromEnvironment {
1273 proxyFn := httpproxy.FromEnvironment().ProxyFunc()
1274 c.proxyFunc = func(req *http.Request) (*url.URL, error) {
1275 return proxyFn(req.URL)
1276 }
1277 return
1278 }
1279 if c.ProxyURL.URL != nil && c.ProxyURL.URL.String() != "" {
1280 if c.NoProxy == "" {
1281 c.proxyFunc = http.ProxyURL(c.ProxyURL.URL)
1282 return
1283 }
1284 proxy := &httpproxy.Config{
1285 HTTPProxy: c.ProxyURL.String(),
1286 HTTPSProxy: c.ProxyURL.String(),
1287 NoProxy: c.NoProxy,
1288 }
1289 proxyFn := proxy.ProxyFunc()
1290 c.proxyFunc = func(req *http.Request) (*url.URL, error) {
1291 return proxyFn(req.URL)
1292 }
1293 }
1294 return
1295 }
1296
1297
1298 func (c *ProxyConfig) GetProxyConnectHeader() http.Header {
1299 return c.ProxyConnectHeader.HTTPHeader()
1300 }
1301
View as plain text