...
1 package mutil
2
3 import (
4 "bufio"
5 "io"
6 "net"
7 "net/http"
8 )
9
10
11
// WriterProxy is a proxy around an http.ResponseWriter that lets callers
// observe the response as it is produced (status code, body size) and
// optionally mirror the body to another writer.
type WriterProxy interface {
	http.ResponseWriter

	// Status returns the HTTP status code that was written, or 0 if no
	// status has been written yet.
	Status() int

	// BytesWritten returns the total number of body bytes accepted by the
	// wrapped writer so far.
	BytesWritten() int

	// Tee causes the response body to also be written to w, in addition to
	// being proxied to the wrapped writer. Only one tee writer is kept:
	// setting a second one replaces the first. Body bytes are written to the
	// wrapped writer first, and only the bytes it accepted are then written
	// to w.
	// NOTE(review): the implementations in this file do no locking, so w
	// presumably must not be changed concurrently with writes — confirm.
	Tee(w io.Writer)

	// Unwrap returns the original http.ResponseWriter being proxied.
	Unwrap() http.ResponseWriter
}
29
30
31
32 func WrapWriter(w http.ResponseWriter) WriterProxy {
33 _, cn := w.(http.CloseNotifier)
34 _, fl := w.(http.Flusher)
35 _, hj := w.(http.Hijacker)
36 _, rf := w.(io.ReaderFrom)
37
38 bw := basicWriter{ResponseWriter: w}
39 if cn && fl && hj && rf {
40 return &fancyWriter{bw}
41 }
42 if fl {
43 return &flushWriter{bw}
44 }
45 return &bw
46 }
47
48
49
// basicWriter is the base WriterProxy implementation. It records the status
// code and the number of body bytes written, and can mirror the body to an
// optional tee writer.
type basicWriter struct {
	http.ResponseWriter
	wroteHeader bool      // set once WriteHeader has been forwarded
	code        int       // status passed to the first WriteHeader call
	bytes       int       // body bytes accepted by the wrapped writer
	tee         io.Writer // optional mirror for body bytes; nil if unset
}

// WriteHeader records code and forwards it to the wrapped writer. Only the
// first call is forwarded. Later calls are ignored, so Status always reports
// the code that was actually forwarded.
func (b *basicWriter) WriteHeader(code int) {
	if b.wroteHeader {
		return
	}
	b.code = code
	b.wroteHeader = true
	b.ResponseWriter.WriteHeader(code)
}

// Write sends buf to the wrapped writer. If no status has been written yet,
// it first sends 200 OK. The bytes accepted by the wrapped writer are also
// written to the tee writer, if one is set. The wrapped writer's error takes
// precedence; the tee's error is returned only if the wrapped write succeeded.
func (b *basicWriter) Write(buf []byte) (int, error) {
	b.maybeWriteHeader()
	n, err := b.ResponseWriter.Write(buf)
	if b.tee != nil {
		// Mirror only what the wrapped writer accepted, so the tee matches
		// what was proxied.
		if _, teeErr := b.tee.Write(buf[:n]); err == nil {
			err = teeErr
		}
	}
	b.bytes += n
	return n, err
}

// maybeWriteHeader sends an implicit 200 OK if no status has been written yet.
func (b *basicWriter) maybeWriteHeader() {
	if !b.wroteHeader {
		b.WriteHeader(http.StatusOK)
	}
}

// Status returns the recorded status code, or 0 if none has been written.
func (b *basicWriter) Status() int {
	return b.code
}

// BytesWritten returns the number of body bytes accepted by the wrapped writer.
func (b *basicWriter) BytesWritten() int {
	return b.bytes
}

// Tee sets w as the mirror for subsequent body writes, replacing any previous
// tee writer. Passing nil disables mirroring.
func (b *basicWriter) Tee(w io.Writer) {
	b.tee = w
}

// Unwrap returns the wrapped http.ResponseWriter.
func (b *basicWriter) Unwrap() http.ResponseWriter {
	return b.ResponseWriter
}

// fancyWriter is a WriterProxy for wrapped writers that implement
// http.CloseNotifier, http.Flusher, http.Hijacker and io.ReaderFrom. Each of
// those methods type-asserts the wrapped writer without checking, so it
// panics if the wrapped writer does not implement the corresponding
// interface.
type fancyWriter struct {
	basicWriter
}

// CloseNotify forwards to the wrapped writer's http.CloseNotifier.
// NOTE(review): http.CloseNotifier is deprecated in net/http in favor of
// Request.Context; kept for compatibility.
func (f *fancyWriter) CloseNotify() <-chan bool {
	cn := f.basicWriter.ResponseWriter.(http.CloseNotifier)
	return cn.CloseNotify()
}

// Flush forwards to the wrapped writer's http.Flusher.
func (f *fancyWriter) Flush() {
	fl := f.basicWriter.ResponseWriter.(http.Flusher)
	fl.Flush()
}

// Hijack forwards to the wrapped writer's http.Hijacker.
func (f *fancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hj := f.basicWriter.ResponseWriter.(http.Hijacker)
	return hj.Hijack()
}

// ReadFrom copies r into the response body.
//
// With a tee writer set, the copy goes through basicWriter.Write (via
// io.Copy) so the body is mirrored. Without one, the copy is delegated to the
// wrapped writer's io.ReaderFrom after sending any pending implicit 200 OK. In
// both cases each copied byte is counted in BytesWritten exactly once.
func (f *fancyWriter) ReadFrom(r io.Reader) (int64, error) {
	if f.basicWriter.tee != nil {
		// basicWriter.Write already adds every accepted byte to f.bytes, so
		// io.Copy's total must not be added again (doing so double-counted).
		return io.Copy(&f.basicWriter, r)
	}
	rf := f.basicWriter.ResponseWriter.(io.ReaderFrom)
	f.basicWriter.maybeWriteHeader()

	n, err := rf.ReadFrom(r)
	f.bytes += int(n)
	return n, err
}
138
139 type flushWriter struct {
140 basicWriter
141 }
142
143 func (f *flushWriter) Flush() {
144 fl := f.basicWriter.ResponseWriter.(http.Flusher)
145 fl.Flush()
146 }
147
148 var (
149 _ http.CloseNotifier = &fancyWriter{}
150 _ http.Flusher = &fancyWriter{}
151 _ http.Hijacker = &fancyWriter{}
152 _ io.ReaderFrom = &fancyWriter{}
153 _ http.Flusher = &flushWriter{}
154 )
155
View as plain text