...
1 package middleware
2
3
4
5
6 import (
7 "bufio"
8 "io"
9 "net"
10 "net/http"
11 )
12
13
14
15 func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter {
16 _, fl := w.(http.Flusher)
17
18 bw := basicWriter{ResponseWriter: w}
19
20 if protoMajor == 2 {
21 _, ps := w.(http.Pusher)
22 if fl && ps {
23 return &http2FancyWriter{bw}
24 }
25 } else {
26 _, hj := w.(http.Hijacker)
27 _, rf := w.(io.ReaderFrom)
28 if fl && hj && rf {
29 return &httpFancyWriter{bw}
30 }
31 }
32 if fl {
33 return &flushWriter{bw}
34 }
35
36 return &bw
37 }
38
39
40
// WrapResponseWriter is a proxy around an http.ResponseWriter that lets
// middleware observe the response being produced: its status code, the
// number of body bytes written, and optionally a copy of the body.
type WrapResponseWriter interface {
	http.ResponseWriter

	// Status returns the HTTP status code recorded for the response, or 0
	// if no status has been recorded yet.
	Status() int

	// BytesWritten returns the number of body bytes written so far.
	BytesWritten() int

	// Tee causes body bytes to also be written to the given io.Writer,
	// after they have been written to the wrapped http.ResponseWriter. Only
	// one tee is kept: a later call replaces the earlier one, and nil
	// disables teeing.
	Tee(io.Writer)

	// Unwrap returns the original, underlying http.ResponseWriter.
	Unwrap() http.ResponseWriter
}
58
59
60
// basicWriter is the minimal WrapResponseWriter implementation. It forwards
// everything to the embedded http.ResponseWriter while recording the final
// status code and the number of body bytes written, and optionally copies
// the body to a tee writer.
type basicWriter struct {
	http.ResponseWriter
	wroteHeader bool      // a final (non-1xx) status has been recorded
	code        int       // recorded final status; 0 until one is recorded
	bytes       int       // body bytes accepted by the wrapped writer
	tee         io.Writer // optional secondary destination for body bytes
}

// WriteHeader forwards code to the wrapped writer. The first final status
// (anything other than an informational 1xx, with 101 Switching Protocols
// counting as final) is recorded and forwarded; later calls are dropped.
// Informational codes such as 103 Early Hints may precede the final status,
// so they are forwarded without being recorded and without blocking the
// real status that follows.
func (b *basicWriter) WriteHeader(code int) {
	if b.wroteHeader {
		return
	}
	if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
		b.ResponseWriter.WriteHeader(code)
		return
	}
	b.code = code
	b.wroteHeader = true
	b.ResponseWriter.WriteHeader(code)
}

// Write sends buf to the wrapped writer, first recording an implicit 200 OK
// if no status has been set. If a tee is configured, the bytes the wrapped
// writer accepted are copied to it as well. The returned count is the
// wrapped writer's, and so is the error unless only the tee failed.
func (b *basicWriter) Write(buf []byte) (int, error) {
	b.maybeWriteHeader()
	n, err := b.ResponseWriter.Write(buf)
	if b.tee != nil {
		_, err2 := b.tee.Write(buf[:n])
		// Prefer errors generated by the proxied writer.
		if err == nil {
			err = err2
		}
	}
	b.bytes += n
	return n, err
}

// maybeWriteHeader writes a 200 OK status unless a final status was already
// recorded.
func (b *basicWriter) maybeWriteHeader() {
	if !b.wroteHeader {
		b.WriteHeader(http.StatusOK)
	}
}

// Status returns the recorded final status code, or 0 if none was recorded.
func (b *basicWriter) Status() int {
	return b.code
}

// BytesWritten returns the number of body bytes accepted by the wrapped
// writer so far.
func (b *basicWriter) BytesWritten() int {
	return b.bytes
}

// Tee sets (or, with nil, clears) the writer that receives a copy of every
// body write.
func (b *basicWriter) Tee(w io.Writer) {
	b.tee = w
}

// Unwrap returns the wrapped http.ResponseWriter.
func (b *basicWriter) Unwrap() http.ResponseWriter {
	return b.ResponseWriter
}
112
113 type flushWriter struct {
114 basicWriter
115 }
116
117 func (f *flushWriter) Flush() {
118 f.wroteHeader = true
119 fl := f.basicWriter.ResponseWriter.(http.Flusher)
120 fl.Flush()
121 }
122
123 var _ http.Flusher = &flushWriter{}
124
125
126
127
128
129 type httpFancyWriter struct {
130 basicWriter
131 }
132
133 func (f *httpFancyWriter) Flush() {
134 f.wroteHeader = true
135 fl := f.basicWriter.ResponseWriter.(http.Flusher)
136 fl.Flush()
137 }
138
139 func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
140 hj := f.basicWriter.ResponseWriter.(http.Hijacker)
141 return hj.Hijack()
142 }
143
144 func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error {
145 return f.basicWriter.ResponseWriter.(http.Pusher).Push(target, opts)
146 }
147
148 func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) {
149 if f.basicWriter.tee != nil {
150 n, err := io.Copy(&f.basicWriter, r)
151 f.basicWriter.bytes += int(n)
152 return n, err
153 }
154 rf := f.basicWriter.ResponseWriter.(io.ReaderFrom)
155 f.basicWriter.maybeWriteHeader()
156 n, err := rf.ReadFrom(r)
157 f.basicWriter.bytes += int(n)
158 return n, err
159 }
160
161 var _ http.Flusher = &httpFancyWriter{}
162 var _ http.Hijacker = &httpFancyWriter{}
163 var _ http.Pusher = &http2FancyWriter{}
164 var _ io.ReaderFrom = &httpFancyWriter{}
165
166
167
168
169
170 type http2FancyWriter struct {
171 basicWriter
172 }
173
174 func (f *http2FancyWriter) Flush() {
175 f.wroteHeader = true
176 fl := f.basicWriter.ResponseWriter.(http.Flusher)
177 fl.Flush()
178 }
179
180 var _ http.Flusher = &http2FancyWriter{}
181
View as plain text