1
2
3
4
5
6
7
8
9
10
11
12
13
14 package main
15
16 import (
17 "bufio"
18 "bytes"
19 "flag"
20 "fmt"
21 "io"
22 "log"
23 "net/http"
24 "net/textproto"
25 "net/url"
26 "os"
27 "os/exec"
28 "path/filepath"
29 "strings"
30 )
31
// v holds the value of the -v command-line flag, which defaults to false.
//
// NOTE(review): nothing in this file's visible code dereferences v; GOAUTH
// command output is echoed to stderr unconditionally by getCreds. Either gate
// that echo on *v or remove the flag.
var (
	v = flag.Bool("v", false, "if true, log GOAUTH responses to stderr")
)
33
34 func main() {
35 log.SetFlags(log.LstdFlags | log.Lshortfile)
36 flag.Parse()
37 args := flag.Args()
38 if len(args) != 1 {
39 log.Fatalf("usage: [GOAUTH=CMD...] %s URL", filepath.Base(os.Args[0]))
40 }
41
42 resp := try(args[0], nil)
43 if resp.StatusCode == http.StatusOK {
44 return
45 }
46
47 resp = try(args[0], resp)
48 if resp.StatusCode != http.StatusOK {
49 os.Exit(1)
50 }
51 }
52
53 func try(url string, prev *http.Response) *http.Response {
54 req := new(http.Request)
55 if prev != nil {
56 *req = *prev.Request
57 } else {
58 var err error
59 req, err = http.NewRequest("HEAD", os.Args[1], nil)
60 if err != nil {
61 log.Fatal(err)
62 }
63 }
64
65 goauth:
66 for _, argList := range strings.Split(os.Getenv("GOAUTH"), ";") {
67
68
69 args := strings.Split(argList, " ")
70 if len(args) == 0 || args[0] == "" {
71 log.Fatalf("invalid or empty command in GOAUTH")
72 }
73
74 creds, err := getCreds(args, prev)
75 if err != nil {
76 log.Fatal(err)
77 }
78 for _, c := range creds {
79 if c.Apply(req) {
80 fmt.Fprintf(os.Stderr, "# request to %s\n", req.URL)
81 fmt.Fprintf(os.Stderr, "%s %s %s\n", req.Method, req.URL, req.Proto)
82 req.Header.Write(os.Stderr)
83 fmt.Fprintln(os.Stderr)
84 break goauth
85 }
86 }
87 }
88
89 resp, err := http.DefaultClient.Do(req)
90 if err != nil {
91 log.Fatal(err)
92 }
93 defer resp.Body.Close()
94
95 if resp.StatusCode != http.StatusOK && resp.StatusCode < 400 || resp.StatusCode > 500 {
96 log.Fatalf("unexpected status: %v", resp.Status)
97 }
98
99 fmt.Fprintf(os.Stderr, "# response from %s\n", resp.Request.URL)
100 formatHead(os.Stderr, resp)
101 return resp
102 }
103
// formatHead writes the head of resp to out: a "PROTO STATUS" line, the
// header fields as written by http.Header.Write, and a final "\n" that ends
// the header section with a blank line. It exits the program via log.Fatal
// if writing the header fields fails; other write errors are ignored.
func formatHead(out io.Writer, resp *http.Response) {
	statusLine := fmt.Sprintf("%s %s\n", resp.Proto, resp.Status)
	io.WriteString(out, statusLine)

	err := resp.Header.Write(out)
	if err != nil {
		log.Fatal(err)
	}
	io.WriteString(out, "\n")
}
111
112 type Cred struct {
113 URLPrefixes []*url.URL
114 Header http.Header
115 }
116
117 func (c Cred) Apply(req *http.Request) bool {
118 if req.URL == nil {
119 return false
120 }
121 ok := false
122 for _, prefix := range c.URLPrefixes {
123 if prefix.Host == req.URL.Host &&
124 (req.URL.Path == prefix.Path ||
125 (strings.HasPrefix(req.URL.Path, prefix.Path) &&
126 (strings.HasSuffix(prefix.Path, "/") ||
127 req.URL.Path[len(prefix.Path)] == '/'))) {
128 ok = true
129 break
130 }
131 }
132 if !ok {
133 return false
134 }
135
136 for k, vs := range c.Header {
137 req.Header.Del(k)
138 for _, v := range vs {
139 req.Header.Add(k, v)
140 }
141 }
142 return true
143 }
144
145 func (c Cred) String() string {
146 var buf strings.Builder
147 for _, u := range c.URLPrefixes {
148 fmt.Fprintln(&buf, u)
149 }
150 buf.WriteString("\n")
151 c.Header.Write(&buf)
152 buf.WriteString("\n")
153 return buf.String()
154 }
155
156 func getCreds(args []string, resp *http.Response) ([]Cred, error) {
157 cmd := exec.Command(args[0], args[1:]...)
158 cmd.Stderr = os.Stderr
159
160 if resp != nil {
161 u := *resp.Request.URL
162 u.RawQuery = ""
163 cmd.Args = append(cmd.Args, u.String())
164 }
165
166 var head strings.Builder
167 if resp != nil {
168 formatHead(&head, resp)
169 }
170 cmd.Stdin = strings.NewReader(head.String())
171
172 fmt.Fprintf(os.Stderr, "# %s\n", strings.Join(cmd.Args, " "))
173 out, err := cmd.Output()
174 if err != nil {
175 return nil, fmt.Errorf("%s: %v", strings.Join(cmd.Args, " "), err)
176 }
177 os.Stderr.Write(out)
178 os.Stderr.WriteString("\n")
179
180 var creds []Cred
181 r := textproto.NewReader(bufio.NewReader(bytes.NewReader(out)))
182 line := 0
183 readLoop:
184 for {
185 var prefixes []*url.URL
186 for {
187 prefix, err := r.ReadLine()
188 if err == io.EOF {
189 if len(prefixes) > 0 {
190 return nil, fmt.Errorf("line %d: %v", line, io.ErrUnexpectedEOF)
191 }
192 break readLoop
193 }
194 line++
195
196 if prefix == "" {
197 if len(prefixes) == 0 {
198 return nil, fmt.Errorf("line %d: unexpected newline", line)
199 }
200 break
201 }
202 u, err := url.Parse(prefix)
203 if err != nil {
204 return nil, fmt.Errorf("line %d: malformed URL: %v", line, err)
205 }
206 if u.Scheme != "https" {
207 return nil, fmt.Errorf("line %d: non-HTTPS URL %q", line, prefix)
208 }
209 if len(u.RawQuery) > 0 {
210 return nil, fmt.Errorf("line %d: unexpected query string in URL %q", line, prefix)
211 }
212 if len(u.Fragment) > 0 {
213 return nil, fmt.Errorf("line %d: unexpected fragment in URL %q", line, prefix)
214 }
215 prefixes = append(prefixes, u)
216 }
217
218 header, err := r.ReadMIMEHeader()
219 if err != nil {
220 return nil, fmt.Errorf("headers at line %d: %v", line, err)
221 }
222 if len(header) > 0 {
223 creds = append(creds, Cred{
224 URLPrefixes: prefixes,
225 Header: http.Header(header),
226 })
227 }
228 }
229
230 return creds, nil
231 }
232
View as plain text