...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package cmux
16
17 import (
18 "errors"
19 "fmt"
20 "io"
21 "net"
22 "sync"
23 "time"
24 )
25
26
// Matcher, MatchWriter and ErrorHandler are the hooks through which callers
// decide where an accepted connection is routed and how serving errors are
// treated.
type (
	// Matcher inspects bytes read from a freshly accepted connection through
	// the supplied io.Reader and reports whether the connection belongs to the
	// listener it was registered with.
	Matcher func(io.Reader) bool

	// MatchWriter is like Matcher, but it also receives an io.Writer (the
	// underlying connection), so it may write to the peer while deciding.
	MatchWriter func(io.Writer, io.Reader) bool

	// ErrorHandler is invoked with errors encountered while serving and
	// reports whether the multiplexer may carry on. Returning false makes
	// Serve return on an Accept error, and makes an unmatched connection
	// close the root listener.
	ErrorHandler func(error) bool
)
35
// Compile-time check that ErrNotMatched satisfies net.Error.
var _ net.Error = ErrNotMatched{}

// ErrNotMatched is the error handed to the ErrorHandler when no registered
// matcher claims a connection. The multiplexer closes the connection before
// reporting this error.
type ErrNotMatched struct {
	c net.Conn
}

// Error describes the unmatched connection by its remote address. A zero
// ErrNotMatched (nil connection) is rendered as "<nil>" instead of panicking.
func (e ErrNotMatched) Error() string {
	var addr net.Addr
	if e.c != nil {
		addr = e.c.RemoteAddr()
	}
	return fmt.Sprintf("mux: connection %v not matched by any matcher", addr)
}

// Temporary reports true. An unmatched connection says nothing about the
// health of the root listener, so serving may continue.
func (e ErrNotMatched) Temporary() bool { return true }

// Timeout reports false; an unmatched connection is not a timeout.
func (e ErrNotMatched) Timeout() bool { return false }
54
// errListenerClosed is a string-backed error that also satisfies net.Error,
// reporting itself as neither temporary nor a timeout.
type errListenerClosed string

func (e errListenerClosed) Error() string   { return string(e) }
func (e errListenerClosed) Temporary() bool { return false }
func (e errListenerClosed) Timeout() bool   { return false }

var (
	// ErrListenerClosed is returned by a child listener's Accept once its
	// connection queue has been closed and drained.
	ErrListenerClosed = errListenerClosed("mux: listener closed")

	// ErrServerClosed is returned by a child listener's Accept once the
	// multiplexer's done channels have been closed (by Close or when Serve
	// returns).
	ErrServerClosed = errors.New("mux: server closed")

	// noTimeout is the zero Duration; a read timeout not greater than it
	// means no read deadline is applied while sniffing.
	noTimeout time.Duration
)
70
71
72 func New(l net.Listener) CMux {
73 return &cMux{
74 root: l,
75 bufLen: 1024,
76 errh: func(_ error) bool { return true },
77 donec: make(chan struct{}),
78 readTimeout: noTimeout,
79 }
80 }
81
82
83 type CMux interface {
84
85
86
87
88 Match(...Matcher) net.Listener
89
90
91
92
93
94
95
96 MatchWithWriters(...MatchWriter) net.Listener
97
98
99 Serve() error
100
101 Close()
102
103 HandleError(ErrorHandler)
104
105 SetReadTimeout(time.Duration)
106 }
107
108 type matchersListener struct {
109 ss []MatchWriter
110 l muxListener
111 }
112
113 type cMux struct {
114 root net.Listener
115 bufLen int
116 errh ErrorHandler
117 sls []matchersListener
118 readTimeout time.Duration
119 donec chan struct{}
120 mu sync.Mutex
121 }
122
123 func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
124 mws := make([]MatchWriter, 0, len(matchers))
125 for _, m := range matchers {
126 cm := m
127 mws = append(mws, func(w io.Writer, r io.Reader) bool {
128 return cm(r)
129 })
130 }
131 return mws
132 }
133
134 func (m *cMux) Match(matchers ...Matcher) net.Listener {
135 mws := matchersToMatchWriters(matchers)
136 return m.MatchWithWriters(mws...)
137 }
138
139 func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
140 ml := muxListener{
141 Listener: m.root,
142 connc: make(chan net.Conn, m.bufLen),
143 donec: make(chan struct{}),
144 }
145 m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
146 return ml
147 }
148
149 func (m *cMux) SetReadTimeout(t time.Duration) {
150 m.readTimeout = t
151 }
152
153 func (m *cMux) Serve() error {
154 var wg sync.WaitGroup
155
156 defer func() {
157 m.closeDoneChans()
158 wg.Wait()
159
160 for _, sl := range m.sls {
161 close(sl.l.connc)
162
163 for c := range sl.l.connc {
164 _ = c.Close()
165 }
166 }
167 }()
168
169 for {
170 c, err := m.root.Accept()
171 if err != nil {
172 if !m.handleErr(err) {
173 return err
174 }
175 continue
176 }
177
178 wg.Add(1)
179 go m.serve(c, m.donec, &wg)
180 }
181 }
182
183 func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
184 defer wg.Done()
185
186 muc := newMuxConn(c)
187 if m.readTimeout > noTimeout {
188 _ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
189 }
190 for _, sl := range m.sls {
191 for _, s := range sl.ss {
192 matched := s(muc.Conn, muc.startSniffing())
193 if matched {
194 muc.doneSniffing()
195 if m.readTimeout > noTimeout {
196 _ = c.SetReadDeadline(time.Time{})
197 }
198 select {
199 case sl.l.connc <- muc:
200 case <-donec:
201 _ = c.Close()
202 }
203 return
204 }
205 }
206 }
207
208 _ = c.Close()
209 err := ErrNotMatched{c: c}
210 if !m.handleErr(err) {
211 _ = m.root.Close()
212 }
213 }
214
215 func (m *cMux) Close() {
216 m.closeDoneChans()
217 }
218
219 func (m *cMux) closeDoneChans() {
220 m.mu.Lock()
221 defer m.mu.Unlock()
222
223 select {
224 case <-m.donec:
225
226 default:
227 close(m.donec)
228 }
229 for _, sl := range m.sls {
230 select {
231 case <-sl.l.donec:
232
233 default:
234 close(sl.l.donec)
235 }
236 }
237 }
238
239 func (m *cMux) HandleError(h ErrorHandler) {
240 m.errh = h
241 }
242
243 func (m *cMux) handleErr(err error) bool {
244 if !m.errh(err) {
245 return false
246 }
247
248 if ne, ok := err.(net.Error); ok {
249 return ne.Temporary()
250 }
251
252 return false
253 }
254
255 type muxListener struct {
256 net.Listener
257 connc chan net.Conn
258 donec chan struct{}
259 }
260
261 func (l muxListener) Accept() (net.Conn, error) {
262 select {
263 case c, ok := <-l.connc:
264 if !ok {
265 return nil, ErrListenerClosed
266 }
267 return c, nil
268 case <-l.donec:
269 return nil, ErrServerClosed
270 }
271 }
272
273
274 type MuxConn struct {
275 net.Conn
276 buf bufferedReader
277 }
278
279 func newMuxConn(c net.Conn) *MuxConn {
280 return &MuxConn{
281 Conn: c,
282 buf: bufferedReader{source: c},
283 }
284 }
285
286
287
288
289
290
291
292
293
294
295
296 func (m *MuxConn) Read(p []byte) (int, error) {
297 return m.buf.Read(p)
298 }
299
300 func (m *MuxConn) startSniffing() io.Reader {
301 m.buf.reset(true)
302 return &m.buf
303 }
304
305 func (m *MuxConn) doneSniffing() {
306 m.buf.reset(false)
307 }
308
View as plain text