...

Source file src/github.com/soheilhy/cmux/cmux.go

Documentation: github.com/soheilhy/cmux

     1  // Copyright 2016 The CMux Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
    12  // implied. See the License for the specific language governing
    13  // permissions and limitations under the License.
    14  
    15  package cmux
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"net"
    22  	"sync"
    23  	"time"
    24  )
    25  
    26  // Matcher matches a connection based on its content.
    27  type Matcher func(io.Reader) bool
    28  
    29  // MatchWriter is a match that can also write response (say to do handshake).
    30  type MatchWriter func(io.Writer, io.Reader) bool
    31  
    32  // ErrorHandler handles an error and returns whether
    33  // the mux should continue serving the listener.
    34  type ErrorHandler func(error) bool
    35  
    36  var _ net.Error = ErrNotMatched{}
    37  
    38  // ErrNotMatched is returned whenever a connection is not matched by any of
    39  // the matchers registered in the multiplexer.
    40  type ErrNotMatched struct {
    41  	c net.Conn
    42  }
    43  
    44  func (e ErrNotMatched) Error() string {
    45  	return fmt.Sprintf("mux: connection %v not matched by an matcher",
    46  		e.c.RemoteAddr())
    47  }
    48  
    49  // Temporary implements the net.Error interface.
    50  func (e ErrNotMatched) Temporary() bool { return true }
    51  
    52  // Timeout implements the net.Error interface.
    53  func (e ErrNotMatched) Timeout() bool { return false }
    54  
    55  type errListenerClosed string
    56  
    57  func (e errListenerClosed) Error() string   { return string(e) }
    58  func (e errListenerClosed) Temporary() bool { return false }
    59  func (e errListenerClosed) Timeout() bool   { return false }
    60  
    61  // ErrListenerClosed is returned from muxListener.Accept when the underlying
    62  // listener is closed.
    63  var ErrListenerClosed = errListenerClosed("mux: listener closed")
    64  
    65  // ErrServerClosed is returned from muxListener.Accept when mux server is closed.
    66  var ErrServerClosed = errors.New("mux: server closed")
    67  
    68  // for readability of readTimeout
    69  var noTimeout time.Duration
    70  
    71  // New instantiates a new connection multiplexer.
    72  func New(l net.Listener) CMux {
    73  	return &cMux{
    74  		root:        l,
    75  		bufLen:      1024,
    76  		errh:        func(_ error) bool { return true },
    77  		donec:       make(chan struct{}),
    78  		readTimeout: noTimeout,
    79  	}
    80  }
    81  
    82  // CMux is a multiplexer for network connections.
    83  type CMux interface {
    84  	// Match returns a net.Listener that sees (i.e., accepts) only
    85  	// the connections matched by at least one of the matcher.
    86  	//
    87  	// The order used to call Match determines the priority of matchers.
    88  	Match(...Matcher) net.Listener
    89  	// MatchWithWriters returns a net.Listener that accepts only the
    90  	// connections that matched by at least of the matcher writers.
    91  	//
    92  	// Prefer Matchers over MatchWriters, since the latter can write on the
    93  	// connection before the actual handler.
    94  	//
    95  	// The order used to call Match determines the priority of matchers.
    96  	MatchWithWriters(...MatchWriter) net.Listener
    97  	// Serve starts multiplexing the listener. Serve blocks and perhaps
    98  	// should be invoked concurrently within a go routine.
    99  	Serve() error
   100  	// Closes cmux server and stops accepting any connections on listener
   101  	Close()
   102  	// HandleError registers an error handler that handles listener errors.
   103  	HandleError(ErrorHandler)
   104  	// sets a timeout for the read of matchers
   105  	SetReadTimeout(time.Duration)
   106  }
   107  
   108  type matchersListener struct {
   109  	ss []MatchWriter
   110  	l  muxListener
   111  }
   112  
   113  type cMux struct {
   114  	root        net.Listener
   115  	bufLen      int
   116  	errh        ErrorHandler
   117  	sls         []matchersListener
   118  	readTimeout time.Duration
   119  	donec       chan struct{}
   120  	mu          sync.Mutex
   121  }
   122  
   123  func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
   124  	mws := make([]MatchWriter, 0, len(matchers))
   125  	for _, m := range matchers {
   126  		cm := m
   127  		mws = append(mws, func(w io.Writer, r io.Reader) bool {
   128  			return cm(r)
   129  		})
   130  	}
   131  	return mws
   132  }
   133  
   134  func (m *cMux) Match(matchers ...Matcher) net.Listener {
   135  	mws := matchersToMatchWriters(matchers)
   136  	return m.MatchWithWriters(mws...)
   137  }
   138  
   139  func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
   140  	ml := muxListener{
   141  		Listener: m.root,
   142  		connc:    make(chan net.Conn, m.bufLen),
   143  		donec:    make(chan struct{}),
   144  	}
   145  	m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
   146  	return ml
   147  }
   148  
   149  func (m *cMux) SetReadTimeout(t time.Duration) {
   150  	m.readTimeout = t
   151  }
   152  
   153  func (m *cMux) Serve() error {
   154  	var wg sync.WaitGroup
   155  
   156  	defer func() {
   157  		m.closeDoneChans()
   158  		wg.Wait()
   159  
   160  		for _, sl := range m.sls {
   161  			close(sl.l.connc)
   162  			// Drain the connections enqueued for the listener.
   163  			for c := range sl.l.connc {
   164  				_ = c.Close()
   165  			}
   166  		}
   167  	}()
   168  
   169  	for {
   170  		c, err := m.root.Accept()
   171  		if err != nil {
   172  			if !m.handleErr(err) {
   173  				return err
   174  			}
   175  			continue
   176  		}
   177  
   178  		wg.Add(1)
   179  		go m.serve(c, m.donec, &wg)
   180  	}
   181  }
   182  
   183  func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
   184  	defer wg.Done()
   185  
   186  	muc := newMuxConn(c)
   187  	if m.readTimeout > noTimeout {
   188  		_ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
   189  	}
   190  	for _, sl := range m.sls {
   191  		for _, s := range sl.ss {
   192  			matched := s(muc.Conn, muc.startSniffing())
   193  			if matched {
   194  				muc.doneSniffing()
   195  				if m.readTimeout > noTimeout {
   196  					_ = c.SetReadDeadline(time.Time{})
   197  				}
   198  				select {
   199  				case sl.l.connc <- muc:
   200  				case <-donec:
   201  					_ = c.Close()
   202  				}
   203  				return
   204  			}
   205  		}
   206  	}
   207  
   208  	_ = c.Close()
   209  	err := ErrNotMatched{c: c}
   210  	if !m.handleErr(err) {
   211  		_ = m.root.Close()
   212  	}
   213  }
   214  
   215  func (m *cMux) Close() {
   216  	m.closeDoneChans()
   217  }
   218  
   219  func (m *cMux) closeDoneChans() {
   220  	m.mu.Lock()
   221  	defer m.mu.Unlock()
   222  
   223  	select {
   224  	case <-m.donec:
   225  		// Already closed. Don't close again
   226  	default:
   227  		close(m.donec)
   228  	}
   229  	for _, sl := range m.sls {
   230  		select {
   231  		case <-sl.l.donec:
   232  			// Already closed. Don't close again
   233  		default:
   234  			close(sl.l.donec)
   235  		}
   236  	}
   237  }
   238  
   239  func (m *cMux) HandleError(h ErrorHandler) {
   240  	m.errh = h
   241  }
   242  
   243  func (m *cMux) handleErr(err error) bool {
   244  	if !m.errh(err) {
   245  		return false
   246  	}
   247  
   248  	if ne, ok := err.(net.Error); ok {
   249  		return ne.Temporary()
   250  	}
   251  
   252  	return false
   253  }
   254  
   255  type muxListener struct {
   256  	net.Listener
   257  	connc chan net.Conn
   258  	donec chan struct{}
   259  }
   260  
   261  func (l muxListener) Accept() (net.Conn, error) {
   262  	select {
   263  	case c, ok := <-l.connc:
   264  		if !ok {
   265  			return nil, ErrListenerClosed
   266  		}
   267  		return c, nil
   268  	case <-l.donec:
   269  		return nil, ErrServerClosed
   270  	}
   271  }
   272  
   273  // MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
   274  type MuxConn struct {
   275  	net.Conn
   276  	buf bufferedReader
   277  }
   278  
   279  func newMuxConn(c net.Conn) *MuxConn {
   280  	return &MuxConn{
   281  		Conn: c,
   282  		buf:  bufferedReader{source: c},
   283  	}
   284  }
   285  
   286  // From the io.Reader documentation:
   287  //
   288  // When Read encounters an error or end-of-file condition after
   289  // successfully reading n > 0 bytes, it returns the number of
   290  // bytes read.  It may return the (non-nil) error from the same call
   291  // or return the error (and n == 0) from a subsequent call.
   292  // An instance of this general case is that a Reader returning
   293  // a non-zero number of bytes at the end of the input stream may
   294  // return either err == EOF or err == nil.  The next Read should
   295  // return 0, EOF.
   296  func (m *MuxConn) Read(p []byte) (int, error) {
   297  	return m.buf.Read(p)
   298  }
   299  
   300  func (m *MuxConn) startSniffing() io.Reader {
   301  	m.buf.reset(true)
   302  	return &m.buf
   303  }
   304  
   305  func (m *MuxConn) doneSniffing() {
   306  	m.buf.reset(false)
   307  }
   308  

View as plain text