1 package main
2
3 import (
4 "bytes"
5 "fmt"
6 "go/format"
7 "io/ioutil"
8 "os"
9 "strings"
10 )
11
12 type Build struct {
13 Suffix string
14 Tags string
15 Interfaces Interfaces
16 }
17
18 func (b *Build) MustBuild() {
19 prefix := "wrap_generated_"
20 b.Implementation().MustWriteFile(prefix + b.Suffix + ".go")
21 b.Tests().MustWriteFile(prefix + b.Suffix + "_test.go")
22 }
23
24 func (b *Build) writeHeader(g *Generator) {
25 g.Printf(`
26 // +build %s
27 // Code generated by "httpsnoop/codegen"; DO NOT EDIT.
28
29 package httpsnoop
30
31 `, b.Tags)
32 }
33
34 func (b *Build) Implementation() *Generator {
35 ifaces := b.Interfaces
36
37 subIfaces := ifaces[1:]
38
39 var g Generator
40
41 b.writeHeader(&g)
42 g.Printf("import (\n")
43 g.Printf(`"net/http"` + "\n")
44 g.Printf(`"io"` + "\n")
45 g.Printf(`"net"` + "\n")
46 g.Printf(`"bufio"` + "\n")
47 g.Printf(")\n")
48 g.Printf("\n")
49
50
51 for _, iface := range ifaces {
52 for _, fn := range iface.Funcs {
53 g.Printf("// %s is part of the %s interface.\n", fn.Type(), iface.Name)
54 g.Printf("type %s func(%s) (%s)\n", fn.Type(), fn.Args, fn.Returns)
55 g.Printf("\n")
56 }
57 }
58
59
60 g.Printf(`
61 // Hooks defines a set of method interceptors for methods included in
62 // http.ResponseWriter as well as some others. You can think of them as
63 // middleware for the function calls they target. See Wrap for more details.
64 type Hooks struct {
65 `)
66 for _, iface := range ifaces {
67 for _, fn := range iface.Funcs {
68 g.Printf("%s func(%s) %s\n", fn.Name, fn.Type(), fn.Type())
69 }
70 }
71 g.Printf("}\n")
72
73
74 docList := make([]string, len(subIfaces))
75 for i, iface := range subIfaces {
76 docList[i] = "// - " + iface.Name
77 }
78 g.Printf(`
79 // Wrap returns a wrapped version of w that provides the exact same interface
80 // as w. Specifically if w implements any combination of:
81 //
82 %s
83 //
84 // The wrapped version will implement the exact same combination. If no hooks
85 // are set, the wrapped version also behaves exactly as w. Hooks targeting
86 // methods not supported by w are ignored. Any other hooks will intercept the
87 // method they target and may modify the call's arguments and/or return values.
88 // The CaptureMetrics implementation serves as a working example for how the
89 // hooks can be used.
90 `, strings.Join(docList, "\n"))
91 g.Printf("func Wrap(w http.ResponseWriter, hooks Hooks) http.ResponseWriter {\n")
92 g.Printf("rw := &rw{w: w, h: hooks}\n")
93 for i, iface := range subIfaces {
94 g.Printf("_, i%d := w.(%s)\n", i, iface.Name)
95 }
96 g.Printf("switch {\n")
97 combinations := 1 << uint(len(subIfaces))
98 for i := 0; i < combinations; i++ {
99 conditions := make([]string, len(subIfaces))
100 fields := make([]string, 0, len(subIfaces))
101 fields = append(fields, "Unwrapper", "http.ResponseWriter")
102 for j, iface := range subIfaces {
103 ok := i&(1<<uint(len(subIfaces)-j-1)) > 0
104 if !ok {
105 conditions[j] = "!"
106 } else {
107 fields = append(fields, iface.Name)
108 }
109 conditions[j] += fmt.Sprintf("i%d", j)
110 }
111 values := make([]string, len(fields))
112 for i, _ := range fields {
113 values[i] = "rw"
114 }
115 g.Printf("// combination %d/%d\n", i+1, combinations)
116 g.Printf("case %s:\n", strings.Join(conditions, "&&"))
117 fieldsS, valuesS := strings.Join(fields, "\n"), strings.Join(values, ",")
118 g.Printf("return struct{\n%s\n}{%s}\n", fieldsS, valuesS)
119 }
120 g.Printf("}\n")
121 g.Printf("panic(\"unreachable\")")
122 g.Printf("}\n")
123
124
125 g.Printf(`
126 type rw struct {
127 w http.ResponseWriter
128 h Hooks
129 }
130
131 func (w *rw) Unwrap() http.ResponseWriter {
132 return w.w
133 }
134
135 `)
136 for _, iface := range ifaces {
137 for _, fn := range iface.Funcs {
138 g.Printf("func (w *rw) %s(%s) (%s) {\n", fn.Name, fn.Args, fn.Returns)
139 g.Printf("f := w.w.(%s).%s\n", iface.Name, fn.Name)
140 g.Printf("if w.h.%s != nil {\n", fn.Name)
141 g.Printf("f = w.h.%s(f)\n", fn.Name)
142 g.Printf("}\n")
143 if fn.Returns != "" {
144 g.Printf("return ")
145 }
146 g.Printf("f(%s)\n", fn.Args.Names())
147 g.Printf("}\n")
148 g.Printf("\n")
149 }
150 }
151 g.Printf(`
152 type Unwrapper interface {
153 Unwrap() http.ResponseWriter
154 }
155
156 // Unwrap returns the underlying http.ResponseWriter from within zero or more
157 // layers of httpsnoop wrappers.
158 func Unwrap(w http.ResponseWriter) http.ResponseWriter {
159 if rw, ok := w.(Unwrapper); ok {
160 // recurse until rw.Unwrap() returns a non-Unwrapper
161 return Unwrap(rw.Unwrap())
162 } else {
163 return w
164 }
165 }
166 `)
167 return &g
168 }
169
170 func (b *Build) Tests() *Generator {
171 ifaces := b.Interfaces
172
173
174 subIfaces := ifaces[1:]
175
176 var g Generator
177
178 b.writeHeader(&g)
179 g.Printf("import (\n")
180 g.Printf(`"net/http"` + "\n")
181 g.Printf(`"io"` + "\n")
182 g.Printf(`"testing"` + "\n")
183 g.Printf(")\n")
184 g.Printf("\n")
185
186
187 g.Printf("func TestWrap(t *testing.T) {\n")
188 combinations := 1 << uint(len(subIfaces))
189 for i := 0; i < combinations; i++ {
190 fields := make([]string, 0, len(subIfaces))
191 fields = append(fields, "http.ResponseWriter")
192 expected := make([]bool, len(ifaces))
193 expected[0] = true
194 for j, iface := range subIfaces {
195 ok := i&(1<<uint(len(subIfaces)-j-1)) > 0
196 expected[j+1] = ok
197 if ok {
198 fields = append(fields, iface.Name)
199 }
200 }
201 g.Printf("// combination %d/%d\n", i+1, combinations)
202 g.Printf("{\n")
203 g.Printf(`t.Log("%s")`+"\n", strings.Join(fields, ", "))
204 g.Printf("inner := struct{\n%s\n}{}\n", strings.Join(fields, "\n"))
205 g.Printf("w := Wrap(inner, Hooks{})\n")
206 for i, iface := range ifaces {
207 g.Printf("if _, ok := w.(%s); ok != %t {\n", iface.Name, expected[i])
208 g.Printf("t.Error(\"unexpected interface\");\n")
209 g.Printf("}\n")
210 }
211 g.Printf(`
212 if w, ok := w.(Unwrapper); ok {
213 if w.Unwrap() != inner {
214 t.Error("w.Unwrap() failed")
215 }
216 } else {
217 t.Error("Unwrapper interface not implemented")
218 }`)
219 g.Printf("}\n")
220 g.Printf("\n")
221 }
222 g.Printf("}\n")
223 return &g
224 }
225
// Interfaces is an ordered list of interface descriptions. The generator walks
// it in order, and Build treats its first element as the mandatory base
// interface.
type Interfaces []*Interface

// Interface describes one Go interface by its qualified name (for example
// "http.Flusher") and the methods it declares.
type Interface struct {
	Name  string
	Funcs []*InterfaceFunc
}

// InterfaceFunc describes one interface method. Returns holds the result list
// as Go source without surrounding parentheses (for example "int, error"). It
// is "" when the method has no results.
type InterfaceFunc struct {
	Name    string
	Args    FuncArgs
	Returns string
}

// FuncArgs is a method's parameter list.
type FuncArgs []*FuncArg

// String renders the parameter list as Go source, e.g. "b []byte, n int".
func (fa FuncArgs) String() string {
	return fa.join(func(a *FuncArg) string { return a.Name + " " + a.Type })
}

// Names renders only the parameter names, e.g. "b, n", for forwarding the
// arguments in a call expression.
func (fa FuncArgs) Names() string {
	return fa.join(func(a *FuncArg) string { return a.Name })
}

// join renders each argument with render and joins the pieces with ", ".
func (fa FuncArgs) join(render func(*FuncArg) string) string {
	parts := make([]string, 0, len(fa))
	for _, a := range fa {
		parts = append(parts, render(a))
	}
	return strings.Join(parts, ", ")
}

// FuncArg is a single named, typed parameter.
type FuncArg struct {
	Name, Type string
}

// Type returns the name of the generated func type for fn, e.g. "WriteFunc".
func (fn *InterfaceFunc) Type() string {
	return fn.Name + "Func"
}
265
// Generator accumulates generated Go source text in memory until it is
// formatted and written out.
type Generator struct {
	buf bytes.Buffer
}

// Printf formats according to format and appends the result to the
// generator's buffer.
func (g *Generator) Printf(format string, args ...interface{}) {
	g.buf.WriteString(fmt.Sprintf(format, args...))
}
273
274 func (g *Generator) WriteFile(name string) error {
275 src, err := g.Format()
276 if err != nil {
277 return fmt.Errorf("format: %s: %s:\n\n%s\n", name, err, g.Bytes())
278 } else if err := ioutil.WriteFile(name, src, 0644); err != nil {
279 return err
280 }
281 return nil
282
283 }
284
285 func (g *Generator) MustWriteFile(name string) {
286 if err := g.WriteFile(name); err != nil {
287 fatalf("%s", err)
288 }
289 }
290
291 func (g *Generator) Bytes() []byte {
292 return g.buf.Bytes()
293 }
294
295 func (g *Generator) Format() ([]byte, error) {
296 return format.Source(g.Bytes())
297 }
298
299 func main() {
300 ifaces := Interfaces{
301 {
302 Name: "http.ResponseWriter",
303 Funcs: []*InterfaceFunc{
304 {"Header", nil, "http.Header"},
305 {"WriteHeader", FuncArgs{{"code", "int"}}, ""},
306 {"Write", FuncArgs{{"b", "[]byte"}}, "int, error"},
307 },
308 },
309 {
310 Name: "http.Flusher",
311 Funcs: []*InterfaceFunc{
312 {"Flush", nil, ""},
313 },
314 },
315 {
316 Name: "http.CloseNotifier",
317 Funcs: []*InterfaceFunc{
318 {"CloseNotify", nil, "<-chan bool"},
319 },
320 },
321 {
322 Name: "http.Hijacker",
323 Funcs: []*InterfaceFunc{
324 {"Hijack", nil, "net.Conn, *bufio.ReadWriter, error"},
325 },
326 },
327 {
328 Name: "io.ReaderFrom",
329 Funcs: []*InterfaceFunc{
330 {"ReadFrom", FuncArgs{{"src", "io.Reader"}}, "int64, error"},
331 },
332 },
333 }
334 builds := []Build{
335 {
336 Suffix: "lt_1.8",
337 Tags: "!go1.8",
338 Interfaces: ifaces,
339 },
340 {
341 Suffix: "gteq_1.8",
342 Tags: "go1.8",
343 Interfaces: append(ifaces, &Interface{
344 Name: "http.Pusher",
345 Funcs: []*InterfaceFunc{
346 {"Push", FuncArgs{
347 {"target", "string"},
348 {"opts", "*http.PushOptions"},
349 }, "error"},
350 },
351 }),
352 },
353 }
354 for _, build := range builds {
355 build.MustBuild()
356 }
357 }
358
// fatalf writes the formatted message plus a trailing newline to stderr and
// terminates the process with exit status 1.
func fatalf(format string, args ...interface{}) {
	line := format + "\n"
	fmt.Fprintf(os.Stderr, line, args...)
	os.Exit(1)
}
363
View as plain text