1
16
17 package net
18
19 import (
20 "bytes"
21 "context"
22 "crypto/tls"
23 "errors"
24 "fmt"
25 "io"
26 "mime"
27 "net"
28 "net/http"
29 "net/url"
30 "os"
31 "path"
32 "regexp"
33 "strconv"
34 "strings"
35 "time"
36 "unicode"
37 "unicode/utf8"
38
39 "golang.org/x/net/http2"
40 "k8s.io/klog/v2"
41 netutils "k8s.io/utils/net"
42 )
43
44
45
// JoinPreservingTrailingSlash joins elem with path.Join. If the last non-empty
// element ends in "/", the result is guaranteed to end in "/" as well, which
// path.Join on its own would strip.
func JoinPreservingTrailingSlash(elem ...string) string {
	joined := path.Join(elem...)

	// Walk backwards to the last element that actually contributes text.
	last := ""
	for idx := len(elem) - 1; idx >= 0 && last == ""; idx-- {
		last = elem[idx]
	}

	if strings.HasSuffix(last, "/") && !strings.HasSuffix(joined, "/") {
		joined += "/"
	}
	return joined
}
63
64
// IsTimeout reports whether err, or any error in its wrap chain, is a
// net.Error whose Timeout method returns true.
func IsTimeout(err error) bool {
	var ne net.Error
	if !errors.As(err, &ne) {
		return false
	}
	return ne != nil && ne.Timeout()
}
72
73
74
75
76
77
78
// IsProbableEOF reports whether err looks like the peer closed or reset the
// connection. It matches:
//   - io.EOF or io.ErrUnexpectedEOF, directly or wrapped with %w;
//   - a set of well-known connection-closure messages produced by the
//     net/http and http2 client stacks.
//
// A *url.Error (as returned by http.Client) is unwrapped first, so that the
// classification is based on its underlying cause. A nil err, or a
// *url.Error with a nil cause, is not an EOF.
func IsProbableEOF(err error) bool {
	if err == nil {
		return false
	}
	var uerr *url.Error
	if errors.As(err, &uerr) {
		err = uerr.Err
		// A url.Error may carry a nil cause; calling Error() on it would panic.
		if err == nil {
			return false
		}
	}
	// errors.Is also recognizes EOF sentinels that were wrapped with %w.
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
		return true
	}
	msg := err.Error()
	switch {
	case msg == "http: can't write HTTP request on broken connection":
		return true
	case strings.Contains(msg, "http2: server sent GOAWAY and closed the connection"):
		return true
	case strings.Contains(msg, "connection reset by peer"):
		return true
	case strings.Contains(strings.ToLower(msg), "use of closed network connection"):
		return true
	}
	return false
}
104
105 var defaultTransport = http.DefaultTransport.(*http.Transport)
106
107
108
109 func SetOldTransportDefaults(t *http.Transport) *http.Transport {
110 if t.Proxy == nil || isDefault(t.Proxy) {
111
112
113 t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
114 }
115
116
117 if t.DialContext == nil && t.Dial == nil {
118 t.DialContext = defaultTransport.DialContext
119 }
120 if t.TLSHandshakeTimeout == 0 {
121 t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout
122 }
123 if t.IdleConnTimeout == 0 {
124 t.IdleConnTimeout = defaultTransport.IdleConnTimeout
125 }
126 return t
127 }
128
129
130
131 func SetTransportDefaults(t *http.Transport) *http.Transport {
132 t = SetOldTransportDefaults(t)
133
134 if s := os.Getenv("DISABLE_HTTP2"); len(s) > 0 {
135 klog.Info("HTTP2 has been explicitly disabled")
136 } else if allowsHTTP2(t) {
137 if err := configureHTTP2Transport(t); err != nil {
138 klog.Warningf("Transport failed http2 configuration: %v", err)
139 }
140 }
141 return t
142 }
143
144 func readIdleTimeoutSeconds() int {
145 ret := 30
146
147
148 if s := os.Getenv("HTTP2_READ_IDLE_TIMEOUT_SECONDS"); len(s) > 0 {
149 i, err := strconv.Atoi(s)
150 if err != nil {
151 klog.Warningf("Illegal HTTP2_READ_IDLE_TIMEOUT_SECONDS(%q): %v."+
152 " Default value %d is used", s, err, ret)
153 return ret
154 }
155 ret = i
156 }
157 return ret
158 }
159
160 func pingTimeoutSeconds() int {
161 ret := 15
162 if s := os.Getenv("HTTP2_PING_TIMEOUT_SECONDS"); len(s) > 0 {
163 i, err := strconv.Atoi(s)
164 if err != nil {
165 klog.Warningf("Illegal HTTP2_PING_TIMEOUT_SECONDS(%q): %v."+
166 " Default value %d is used", s, err, ret)
167 return ret
168 }
169 ret = i
170 }
171 return ret
172 }
173
174 func configureHTTP2Transport(t *http.Transport) error {
175 t2, err := http2.ConfigureTransports(t)
176 if err != nil {
177 return err
178 }
179
180
181
182
183
184
185
186
187 t2.ReadIdleTimeout = time.Duration(readIdleTimeoutSeconds()) * time.Second
188 t2.PingTimeout = time.Duration(pingTimeoutSeconds()) * time.Second
189 return nil
190 }
191
192 func allowsHTTP2(t *http.Transport) bool {
193 if t.TLSClientConfig == nil || len(t.TLSClientConfig.NextProtos) == 0 {
194
195 return true
196 }
197 for _, p := range t.TLSClientConfig.NextProtos {
198 if p == http2.NextProtoTLS {
199
200 return true
201 }
202 }
203
204 return false
205 }
206
// RoundTripperWrapper is implemented by round trippers that decorate another
// http.RoundTripper and can return the one they wrap.
type RoundTripperWrapper interface {
	http.RoundTripper
	WrappedRoundTripper() http.RoundTripper
}

// DialFunc is a context-aware dial function with the same shape as
// http.Transport.DialContext.
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)

// DialerFor returns the dial function used by transport, unwrapping
// RoundTripperWrapper layers as needed:
//   - nil transport: (nil, nil);
//   - *http.Transport: its DialContext if set, otherwise an adapter around its
//     legacy Dial field if set, otherwise (nil, nil);
//   - any other type: an "unknown transport type" error.
func DialerFor(transport http.RoundTripper) (DialFunc, error) {
	if transport == nil {
		return nil, nil
	}

	switch rt := transport.(type) {
	case *http.Transport:
		switch {
		case rt.DialContext != nil:
			return rt.DialContext, nil
		case rt.Dial != nil:
			// Adapt the context-less Dial. The field is read at call time,
			// so later reassignment of rt.Dial is observed.
			return func(_ context.Context, network, address string) (net.Conn, error) {
				return rt.Dial(network, address)
			}, nil
		default:
			return nil, nil
		}
	case RoundTripperWrapper:
		return DialerFor(rt.WrappedRoundTripper())
	default:
		return nil, fmt.Errorf("unknown transport type: %T", rt)
	}
}
239
240
241
242
243
244
245 func CloseIdleConnectionsFor(transport http.RoundTripper) {
246 if transport == nil {
247 return
248 }
249 type closeIdler interface {
250 CloseIdleConnections()
251 }
252
253 switch transport := transport.(type) {
254 case closeIdler:
255 transport.CloseIdleConnections()
256 case RoundTripperWrapper:
257 CloseIdleConnectionsFor(transport.WrappedRoundTripper())
258 default:
259 klog.Warningf("unknown transport type: %T", transport)
260 }
261 }
262
263 type TLSClientConfigHolder interface {
264 TLSClientConfig() *tls.Config
265 }
266
267 func TLSClientConfig(transport http.RoundTripper) (*tls.Config, error) {
268 if transport == nil {
269 return nil, nil
270 }
271
272 switch transport := transport.(type) {
273 case *http.Transport:
274 return transport.TLSClientConfig, nil
275 case TLSClientConfigHolder:
276 return transport.TLSClientConfig(), nil
277 case RoundTripperWrapper:
278 return TLSClientConfig(transport.WrappedRoundTripper())
279 default:
280 return nil, fmt.Errorf("unknown transport type: %T", transport)
281 }
282 }
283
// FormatURL builds a *url.URL from scheme, host, port and path. The host and
// port are combined with net.JoinHostPort, so IPv6 literals get brackets.
func FormatURL(scheme string, host string, port int, path string) *url.URL {
	u := new(url.URL)
	u.Scheme = scheme
	u.Host = net.JoinHostPort(host, strconv.Itoa(port))
	u.Path = path
	return u
}
291
// GetHTTPClient returns the request's User-Agent, or "unknown" when the
// request carries none.
func GetHTTPClient(req *http.Request) string {
	ua := req.UserAgent()
	if ua == "" {
		return "unknown"
	}
	return ua
}
298
299
300
301
302
303
304 func SourceIPs(req *http.Request) []net.IP {
305 var srcIPs []net.IP
306
307 hdr := req.Header
308
309 hdrForwardedFor := hdr.Get("X-Forwarded-For")
310 if hdrForwardedFor != "" {
311
312
313 parts := strings.Split(hdrForwardedFor, ",")
314 for _, part := range parts {
315 ip := netutils.ParseIPSloppy(strings.TrimSpace(part))
316 if ip != nil {
317 srcIPs = append(srcIPs, ip)
318 }
319 }
320 }
321
322
323 hdrRealIp := hdr.Get("X-Real-Ip")
324 if hdrRealIp != "" {
325 ip := netutils.ParseIPSloppy(hdrRealIp)
326
327 if ip != nil && !containsIP(srcIPs, ip) {
328 srcIPs = append(srcIPs, ip)
329 }
330 }
331
332
333 var remoteIP net.IP
334
335 host, _, err := net.SplitHostPort(req.RemoteAddr)
336 if err == nil {
337 remoteIP = netutils.ParseIPSloppy(host)
338 }
339
340 if remoteIP == nil {
341 remoteIP = netutils.ParseIPSloppy(req.RemoteAddr)
342 }
343
344
345 if remoteIP != nil && (len(srcIPs) == 0 || !remoteIP.Equal(srcIPs[len(srcIPs)-1])) {
346 srcIPs = append(srcIPs, remoteIP)
347 }
348
349 return srcIPs
350 }
351
352
// containsIP reports whether ips holds an address equal to ip according to
// net.IP.Equal, which treats an IPv4 address and its IPv4-in-IPv6 form as equal.
func containsIP(ips []net.IP, ip net.IP) bool {
	idx := 0
	for idx < len(ips) && !ips[idx].Equal(ip) {
		idx++
	}
	return idx < len(ips)
}
361
362
363
364
365 func GetClientIP(req *http.Request) net.IP {
366 ips := SourceIPs(req)
367 if len(ips) == 0 {
368 return nil
369 }
370 return ips[0]
371 }
372
373
374
// AppendForwardedForHeader appends the host part of req.RemoteAddr to the
// request's X-Forwarded-For header. All existing X-Forwarded-For values are
// folded into a single comma-separated header value. If RemoteAddr is not in
// "host:port" form, the request is left unchanged.
func AppendForwardedForHeader(req *http.Request) {
	clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return
	}

	value := clientIP
	// The canonical key is read directly so every existing header line is kept.
	if prior, present := req.Header["X-Forwarded-For"]; present {
		value = strings.Join(prior, ", ") + ", " + clientIP
	}
	req.Header.Set("X-Forwarded-For", value)
}
387
// defaultProxyFuncPointer is the %p rendering of http.ProxyFromEnvironment.
// It is computed once at package initialization.
var defaultProxyFuncPointer = fmt.Sprintf("%p", http.ProxyFromEnvironment)

// isDefault reports whether transportProxier is http.ProxyFromEnvironment.
// Go func values cannot be compared with ==, so their %p renderings are
// compared instead.
func isDefault(transportProxier func(*http.Request) (*url.URL, error)) bool {
	return defaultProxyFuncPointer == fmt.Sprintf("%p", transportProxier)
}
395
396
397
398 func NewProxierWithNoProxyCIDR(delegate func(req *http.Request) (*url.URL, error)) func(req *http.Request) (*url.URL, error) {
399
400 noProxyEnv := os.Getenv("NO_PROXY")
401 if noProxyEnv == "" {
402 noProxyEnv = os.Getenv("no_proxy")
403 }
404 noProxyRules := strings.Split(noProxyEnv, ",")
405
406 cidrs := []*net.IPNet{}
407 for _, noProxyRule := range noProxyRules {
408 _, cidr, _ := netutils.ParseCIDRSloppy(noProxyRule)
409 if cidr != nil {
410 cidrs = append(cidrs, cidr)
411 }
412 }
413
414 if len(cidrs) == 0 {
415 return delegate
416 }
417
418 return func(req *http.Request) (*url.URL, error) {
419 ip := netutils.ParseIPSloppy(req.URL.Hostname())
420 if ip == nil {
421 return delegate(req)
422 }
423
424 for _, cidr := range cidrs {
425 if cidr.Contains(ip) {
426 return nil, nil
427 }
428 }
429
430 return delegate(req)
431 }
432 }
433
434
// Dialer opens a network connection on behalf of an HTTP request.
type Dialer interface {
	Dial(req *http.Request) (net.Conn, error)
}

// DialerFunc adapts a plain function to the Dialer interface.
type DialerFunc func(req *http.Request) (net.Conn, error)

// Dial calls fn(req) and returns its results unchanged.
func (fn DialerFunc) Dial(req *http.Request) (net.Conn, error) {
	return fn(req)
}
447
448
// CloneRequest returns a shallow copy of req whose Header map is deep-copied
// with CloneHeader. Header edits on the clone therefore do not affect req.
// Pointer and reference fields such as URL and Body are still shared.
func CloneRequest(req *http.Request) *http.Request {
	clone := *req
	clone.Header = CloneHeader(req.Header)
	return &clone
}

// CloneHeader returns a deep copy of in. Each key gets its own value slice, so
// neither map nor slices are shared. A nil in produces an empty, non-nil Header.
func CloneHeader(in http.Header) http.Header {
	out := make(http.Header, len(in))
	for key, vals := range in {
		out[key] = append(make([]string, 0, len(vals)), vals...)
	}
	return out
}
471
472
// WarningHeader holds the parsed fields of a single HTTP Warning header entry,
// written on the wire as `code agent "text"`.
type WarningHeader struct {
	// Code is the warn-code. The parser accepts exactly three digits.
	Code int

	// Agent is the warn-agent segment. NewWarningHeader writes "-" when no agent
	// is supplied.
	Agent string
	// Text is the warn-text with quoting/escaping removed and, where possible,
	// MIME encoded-words decoded.
	Text string
}
482
483
484
485
486
487 func ParseWarningHeaders(headers []string) ([]WarningHeader, []error) {
488 var (
489 results []WarningHeader
490 errs []error
491 )
492 for _, header := range headers {
493 for len(header) > 0 {
494 result, remainder, err := ParseWarningHeader(header)
495 if err != nil {
496 errs = append(errs, err)
497 break
498 }
499 results = append(results, result)
500 header = remainder
501 }
502 }
503 return results, errs
504 }
505
506 var (
507 codeMatcher = regexp.MustCompile(`^[0-9]{3}$`)
508 wordDecoder = &mime.WordDecoder{}
509 )
510
511
512
513
514 func ParseWarningHeader(header string) (result WarningHeader, remainder string, err error) {
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538 header = strings.TrimSpace(header)
539
540 parts := strings.SplitN(header, " ", 3)
541 if len(parts) != 3 {
542 return WarningHeader{}, "", errors.New("invalid warning header: fewer than 3 segments")
543 }
544 code, agent, textDateRemainder := parts[0], parts[1], parts[2]
545
546
547 if !codeMatcher.Match([]byte(code)) {
548 return WarningHeader{}, "", errors.New("invalid warning header: code segment is not 3 digits between 100-299")
549 }
550 codeInt, _ := strconv.ParseInt(code, 10, 64)
551
552
553 if len(agent) == 0 {
554 return WarningHeader{}, "", errors.New("invalid warning header: empty agent segment")
555 }
556 if !utf8.ValidString(agent) || hasAnyRunes(agent, unicode.IsControl) {
557 return WarningHeader{}, "", errors.New("invalid warning header: invalid agent")
558 }
559
560
561 if len(textDateRemainder) == 0 {
562 return WarningHeader{}, "", errors.New("invalid warning header: empty text segment")
563 }
564
565
566 text, dateAndRemainder, err := parseQuotedString(textDateRemainder)
567 if err != nil {
568 return WarningHeader{}, "", fmt.Errorf("invalid warning header: %v", err)
569 }
570
571 if decodedText, err := wordDecoder.DecodeHeader(text); err == nil {
572 text = decodedText
573 }
574 if !utf8.ValidString(text) || hasAnyRunes(text, unicode.IsControl) {
575 return WarningHeader{}, "", errors.New("invalid warning header: invalid text")
576 }
577 result = WarningHeader{Code: int(codeInt), Agent: agent, Text: text}
578
579 if len(dateAndRemainder) > 0 {
580 if dateAndRemainder[0] == '"' {
581
582 foundEndQuote := false
583 for i := 1; i < len(dateAndRemainder); i++ {
584 if dateAndRemainder[i] == '"' {
585 foundEndQuote = true
586 remainder = strings.TrimSpace(dateAndRemainder[i+1:])
587 break
588 }
589 }
590 if !foundEndQuote {
591 return WarningHeader{}, "", errors.New("invalid warning header: unterminated date segment")
592 }
593 } else {
594 remainder = dateAndRemainder
595 }
596 }
597 if len(remainder) > 0 {
598 if remainder[0] == ',' {
599
600 remainder = strings.TrimSpace(remainder[1:])
601 } else {
602 return WarningHeader{}, "", errors.New("invalid warning header: unexpected token after warn-date")
603 }
604 }
605
606 return result, remainder, nil
607 }
608
// parseQuotedString reads a double-quoted string from the start of
// quotedString. Inside the quotes, a backslash makes the following byte
// literal. It returns the unescaped contents and the whitespace-trimmed text
// after the closing quote. It returns an error if the input is empty, does not
// begin with a quote, or never closes the quote.
func parseQuotedString(quotedString string) (string, string, error) {
	switch {
	case len(quotedString) == 0:
		return "", "", errors.New("invalid quoted string: 0-length")
	case quotedString[0] != '"':
		return "", "", errors.New("invalid quoted string: missing initial quote")
	}

	body := quotedString[1:]
	var sb strings.Builder
	escaped := false
	for idx := 0; idx < len(body); idx++ {
		c := body[idx]
		if escaped {
			// The previous byte was a backslash: take this one verbatim.
			sb.WriteByte(c)
			escaped = false
			continue
		}
		switch c {
		case '\\':
			escaped = true
		case '"':
			return sb.String(), strings.TrimSpace(body[idx+1:]), nil
		default:
			sb.WriteByte(c)
		}
	}
	return "", "", errors.New("invalid quoted string: missing closing quote")
}
654
655 func NewWarningHeader(code int, agent, text string) (string, error) {
656 if code < 0 || code > 999 {
657 return "", errors.New("code must be between 0 and 999")
658 }
659 if len(agent) == 0 {
660 agent = "-"
661 } else if !utf8.ValidString(agent) || strings.ContainsAny(agent, `\"`) || hasAnyRunes(agent, unicode.IsSpace, unicode.IsControl) {
662 return "", errors.New("agent must be valid UTF-8 and must not contain spaces, quotes, backslashes, or control characters")
663 }
664 if !utf8.ValidString(text) || hasAnyRunes(text, unicode.IsControl) {
665 return "", errors.New("text must be valid UTF-8 and must not contain control characters")
666 }
667 return fmt.Sprintf("%03d %s %s", code, agent, makeQuotedString(text)), nil
668 }
669
// hasAnyRunes reports whether any rune of s satisfies at least one of the
// given predicates. With no predicates, or an empty s, it returns false.
func hasAnyRunes(s string, runeCheckers ...func(rune) bool) bool {
	matchesAny := func(r rune) bool {
		for _, check := range runeCheckers {
			if check(r) {
				return true
			}
		}
		return false
	}
	return strings.IndexFunc(s, matchesAny) >= 0
}
680
// makeQuotedString wraps s in double quotes, escaping every '"' and '\' with a
// backslash. Other runes are copied as-is. Invalid UTF-8 bytes come out as
// U+FFFD, because s is ranged over rune by rune.
func makeQuotedString(s string) string {
	var sb strings.Builder
	sb.Grow(len(s) + 2)
	sb.WriteByte('"')
	for _, c := range s {
		if c == '"' || c == '\\' {
			sb.WriteByte('\\')
		}
		sb.WriteRune(c)
	}
	sb.WriteByte('"')
	return sb.String()
}
700
View as plain text