...
1
2
3
4
5 package resty
6
7 import (
8 "errors"
9 "fmt"
10 "net"
11 "net/http"
12 "strings"
13 )
14
// RedirectPolicy decides whether an HTTP redirect may be followed.
// Apply receives the request about to be sent (req) and the requests that
// have already been made (via). The policies in this file return a non-nil
// error to reject the redirect and nil to allow it.
type RedirectPolicy interface {
	Apply(req *http.Request, via []*http.Request) error
}

// RedirectPolicyFunc is a plain function with the same shape as
// RedirectPolicy.Apply, letting ordinary functions be used as policies.
type RedirectPolicyFunc func(*http.Request, []*http.Request) error
29
30
31 func (f RedirectPolicyFunc) Apply(req *http.Request, via []*http.Request) error {
32 return f(req, via)
33 }
34
35
36
37 func NoRedirectPolicy() RedirectPolicy {
38 return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
39 return errors.New("auto redirect is disabled")
40 })
41 }
42
43
44
45 func FlexibleRedirectPolicy(noOfRedirect int) RedirectPolicy {
46 return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
47 if len(via) >= noOfRedirect {
48 return fmt.Errorf("stopped after %d redirects", noOfRedirect)
49 }
50 checkHostAndAddHeaders(req, via[0])
51 return nil
52 })
53 }
54
55
56
57
58 func DomainCheckRedirectPolicy(hostnames ...string) RedirectPolicy {
59 hosts := make(map[string]bool)
60 for _, h := range hostnames {
61 hosts[strings.ToLower(h)] = true
62 }
63
64 fn := RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
65 if ok := hosts[getHostname(req.URL.Host)]; !ok {
66 return errors.New("redirect is not allowed as per DomainCheckRedirectPolicy")
67 }
68
69 return nil
70 })
71
72 return fn
73 }
74
75
76
77
78
// getHostname returns the lower-cased hostname part of host (a URL host
// such as "Example.com:8080" or "[::1]:80"), dropping any port and the
// square brackets around an IPv6 literal.
//
// If host contains a colon but cannot be split into host and port (for
// example an unbracketed IPv6 address like "fe80::1"), the original value is
// kept and lower-cased rather than being replaced by an empty string.
func getHostname(host string) string {
	if strings.Index(host, ":") > 0 {
		// net.SplitHostPort returns an empty host on error, so only adopt
		// its result when the split actually succeeded.
		if h, _, err := net.SplitHostPort(host); err == nil {
			host = h
		}
	}
	// A bracketed IPv6 literal without a port (e.g. "[::1]") is not split
	// above; strip the brackets so it matches the ported form "[::1]:80".
	if len(host) > 1 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	return strings.ToLower(host)
}
86
87
88
89
90
91 func checkHostAndAddHeaders(cur *http.Request, pre *http.Request) {
92 curHostname := getHostname(cur.URL.Host)
93 preHostname := getHostname(pre.URL.Host)
94 if strings.EqualFold(curHostname, preHostname) {
95 for key, val := range pre.Header {
96 cur.Header[key] = val
97 }
98 } else {
99 cur.Header.Set(hdrUserAgentKey, hdrUserAgentValue)
100 }
101 }
102
View as plain text