...

Source file src/github.com/moby/spdystream/connection.go

Documentation: github.com/moby/spdystream

     1  /*
     2     Copyright 2014-2021 Docker Inc.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  package spdystream
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"net"
    24  	"net/http"
    25  	"sync"
    26  	"time"
    27  
    28  	"github.com/moby/spdystream/spdy"
    29  )
    30  
    31  var (
    32  	ErrInvalidStreamId   = errors.New("Invalid stream id")
    33  	ErrTimeout           = errors.New("Timeout occurred")
    34  	ErrReset             = errors.New("Stream reset")
    35  	ErrWriteClosedStream = errors.New("Write on closed stream")
    36  )
    37  
    38  const (
    39  	FRAME_WORKERS = 5
    40  	QUEUE_SIZE    = 50
    41  )
    42  
    43  type StreamHandler func(stream *Stream)
    44  
    45  type AuthHandler func(header http.Header, slot uint8, parent uint32) bool
    46  
// idleAwareFramer wraps a spdy.Framer and signals monitor after every
// successful frame read or write. monitor uses those signals to enforce the
// connection's idle timeout.
type idleAwareFramer struct {
	// f performs the actual frame encoding/decoding.
	f              *spdy.Framer
	// conn is the owning connection; NewConnection sets it after construction.
	conn           *Connection
	// writeLock serializes WriteFrame and guards resetChan being closed and
	// set to nil by monitor.
	writeLock      sync.Mutex
	// resetChan carries one signal per frame read or written; nil once
	// monitor has shut down.
	resetChan      chan struct{}
	// setTimeoutLock guards setTimeoutChan being closed and set to nil.
	setTimeoutLock sync.Mutex
	// setTimeoutChan delivers new idle timeouts to monitor; nil once
	// monitor has shut down.
	setTimeoutChan chan time.Duration
	// timeout is the current idle timeout (0 disables it); only monitor
	// reads or writes it.
	timeout        time.Duration
}
    56  
    57  func newIdleAwareFramer(framer *spdy.Framer) *idleAwareFramer {
    58  	iaf := &idleAwareFramer{
    59  		f:         framer,
    60  		resetChan: make(chan struct{}, 2),
    61  		// setTimeoutChan needs to be buffered to avoid deadlocks when calling setIdleTimeout at about
    62  		// the same time the connection is being closed
    63  		setTimeoutChan: make(chan time.Duration, 1),
    64  	}
    65  	return iaf
    66  }
    67  
    68  func (i *idleAwareFramer) monitor() {
    69  	var (
    70  		timer          *time.Timer
    71  		expired        <-chan time.Time
    72  		resetChan      = i.resetChan
    73  		setTimeoutChan = i.setTimeoutChan
    74  	)
    75  Loop:
    76  	for {
    77  		select {
    78  		case timeout := <-i.setTimeoutChan:
    79  			i.timeout = timeout
    80  			if timeout == 0 {
    81  				if timer != nil {
    82  					timer.Stop()
    83  				}
    84  			} else {
    85  				if timer == nil {
    86  					timer = time.NewTimer(timeout)
    87  					expired = timer.C
    88  				} else {
    89  					timer.Reset(timeout)
    90  				}
    91  			}
    92  		case <-resetChan:
    93  			if timer != nil && i.timeout > 0 {
    94  				timer.Reset(i.timeout)
    95  			}
    96  		case <-expired:
    97  			i.conn.streamCond.L.Lock()
    98  			streams := i.conn.streams
    99  			i.conn.streams = make(map[spdy.StreamId]*Stream)
   100  			i.conn.streamCond.Broadcast()
   101  			i.conn.streamCond.L.Unlock()
   102  			go func() {
   103  				for _, stream := range streams {
   104  					stream.resetStream()
   105  				}
   106  				i.conn.Close()
   107  			}()
   108  		case <-i.conn.closeChan:
   109  			if timer != nil {
   110  				timer.Stop()
   111  			}
   112  
   113  			// Start a goroutine to drain resetChan. This is needed because we've seen
   114  			// some unit tests with large numbers of goroutines get into a situation
   115  			// where resetChan fills up, at least 1 call to Write() is still trying to
   116  			// send to resetChan, the connection gets closed, and this case statement
   117  			// attempts to grab the write lock that Write() already has, causing a
   118  			// deadlock.
   119  			//
   120  			// See https://github.com/moby/spdystream/issues/49 for more details.
   121  			go func() {
   122  				for range resetChan {
   123  				}
   124  			}()
   125  
   126  			go func() {
   127  				for range setTimeoutChan {
   128  				}
   129  			}()
   130  
   131  			i.writeLock.Lock()
   132  			close(resetChan)
   133  			i.resetChan = nil
   134  			i.writeLock.Unlock()
   135  
   136  			i.setTimeoutLock.Lock()
   137  			close(i.setTimeoutChan)
   138  			i.setTimeoutChan = nil
   139  			i.setTimeoutLock.Unlock()
   140  
   141  			break Loop
   142  		}
   143  	}
   144  
   145  	// Drain resetChan
   146  	for range resetChan {
   147  	}
   148  }
   149  
   150  func (i *idleAwareFramer) WriteFrame(frame spdy.Frame) error {
   151  	i.writeLock.Lock()
   152  	defer i.writeLock.Unlock()
   153  	if i.resetChan == nil {
   154  		return io.EOF
   155  	}
   156  	err := i.f.WriteFrame(frame)
   157  	if err != nil {
   158  		return err
   159  	}
   160  
   161  	i.resetChan <- struct{}{}
   162  
   163  	return nil
   164  }
   165  
   166  func (i *idleAwareFramer) ReadFrame() (spdy.Frame, error) {
   167  	frame, err := i.f.ReadFrame()
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	// resetChan should never be closed since it is only closed
   173  	// when the connection has closed its closeChan. This closure
   174  	// only occurs after all Reads have finished
   175  	// TODO (dmcgowan): refactor relationship into connection
   176  	i.resetChan <- struct{}{}
   177  
   178  	return frame, nil
   179  }
   180  
   181  func (i *idleAwareFramer) setIdleTimeout(timeout time.Duration) {
   182  	i.setTimeoutLock.Lock()
   183  	defer i.setTimeoutLock.Unlock()
   184  
   185  	if i.setTimeoutChan == nil {
   186  		return
   187  	}
   188  
   189  	i.setTimeoutChan <- timeout
   190  }
   191  
// Connection is a SPDY session over a single net.Conn. Create it with
// NewConnection and run Serve to process frames from the peer.
type Connection struct {
	// conn is the underlying network connection; shutdown closes it.
	conn   net.Conn
	framer *idleAwareFramer

	// closeChan is closed by Serve when its read loop exits.
	// goneAway is set once a GOAWAY has been sent (Close) or received
	// (handleGoAwayFrame); it is accessed under receiveIdLock.
	// lastStreamChan and goAwayTimeout are set by NotifyClose, closeTimeout
	// by SetCloseTimeout.
	closeChan      chan bool
	goneAway       bool
	lastStreamChan chan<- *Stream
	goAwayTimeout  time.Duration
	closeTimeout   time.Duration

	// streams maps stream ids to live streams. It is guarded by streamLock,
	// which is also streamCond's Locker; streamCond is broadcast whenever the
	// map changes.
	streamLock *sync.RWMutex
	streamCond *sync.Cond
	streams    map[spdy.StreamId]*Stream

	// nextIdLock serializes CreateStream so locally allocated ids increase
	// monotonically. receiveIdLock is held by checkStreamFrame while
	// validateStreamId updates receivedStreamId (the next id accepted from
	// the peer).
	nextIdLock       sync.Mutex
	receiveIdLock    sync.Mutex
	nextStreamId     spdy.StreamId
	receivedStreamId spdy.StreamId

	// pingId is the next ping id to send (guarded by pingIdLock in Ping);
	// pingChans maps outstanding ping ids to the channel Ping waits on.
	pingIdLock sync.Mutex
	pingId     uint32
	pingChans  map[uint32]chan error

	// hasShutdown (guarded by shutdownLock) makes shutdown run once; its
	// close error, if any, is delivered on shutdownChan, which is then closed.
	shutdownLock sync.Mutex
	shutdownChan chan error
	hasShutdown  bool

	// for testing https://github.com/moby/spdystream/pull/56
	dataFrameHandler func(*spdy.DataFrame) error
}
   222  
   223  // NewConnection creates a new spdy connection from an existing
   224  // network connection.
   225  func NewConnection(conn net.Conn, server bool) (*Connection, error) {
   226  	framer, framerErr := spdy.NewFramer(conn, conn)
   227  	if framerErr != nil {
   228  		return nil, framerErr
   229  	}
   230  	idleAwareFramer := newIdleAwareFramer(framer)
   231  	var sid spdy.StreamId
   232  	var rid spdy.StreamId
   233  	var pid uint32
   234  	if server {
   235  		sid = 2
   236  		rid = 1
   237  		pid = 2
   238  	} else {
   239  		sid = 1
   240  		rid = 2
   241  		pid = 1
   242  	}
   243  
   244  	streamLock := new(sync.RWMutex)
   245  	streamCond := sync.NewCond(streamLock)
   246  
   247  	session := &Connection{
   248  		conn:   conn,
   249  		framer: idleAwareFramer,
   250  
   251  		closeChan:     make(chan bool),
   252  		goAwayTimeout: time.Duration(0),
   253  		closeTimeout:  time.Duration(0),
   254  
   255  		streamLock:       streamLock,
   256  		streamCond:       streamCond,
   257  		streams:          make(map[spdy.StreamId]*Stream),
   258  		nextStreamId:     sid,
   259  		receivedStreamId: rid,
   260  
   261  		pingId:    pid,
   262  		pingChans: make(map[uint32]chan error),
   263  
   264  		shutdownChan: make(chan error),
   265  	}
   266  	session.dataFrameHandler = session.handleDataFrame
   267  	idleAwareFramer.conn = session
   268  	go idleAwareFramer.monitor()
   269  
   270  	return session, nil
   271  }
   272  
   273  // Ping sends a ping frame across the connection and
   274  // returns the response time
   275  func (s *Connection) Ping() (time.Duration, error) {
   276  	pid := s.pingId
   277  	s.pingIdLock.Lock()
   278  	if s.pingId > 0x7ffffffe {
   279  		s.pingId = s.pingId - 0x7ffffffe
   280  	} else {
   281  		s.pingId = s.pingId + 2
   282  	}
   283  	s.pingIdLock.Unlock()
   284  	pingChan := make(chan error)
   285  	s.pingChans[pid] = pingChan
   286  	defer delete(s.pingChans, pid)
   287  
   288  	frame := &spdy.PingFrame{Id: pid}
   289  	startTime := time.Now()
   290  	writeErr := s.framer.WriteFrame(frame)
   291  	if writeErr != nil {
   292  		return time.Duration(0), writeErr
   293  	}
   294  	select {
   295  	case <-s.closeChan:
   296  		return time.Duration(0), errors.New("connection closed")
   297  	case err, ok := <-pingChan:
   298  		if ok && err != nil {
   299  			return time.Duration(0), err
   300  		}
   301  		break
   302  	}
   303  	return time.Since(startTime), nil
   304  }
   305  
// Serve handles frames sent from the remote peer, including reply frames
// which are needed to fully initiate connections.  Both clients and servers
// should call Serve in a separate goroutine before creating streams.
// newHandler is invoked for every stream the peer opens.
//
// Serve reads until a read fails or a GOAWAY frame arrives. It then closes
// closeChan, waits for its worker goroutines to finish, handles any pending
// GOAWAY, and finally closes the remote side of every remaining stream and
// clears the stream table.
func (s *Connection) Serve(newHandler StreamHandler) {
	// use a WaitGroup to wait for all frames to be drained after receiving
	// go-away.
	var wg sync.WaitGroup

	// Partition queues to ensure stream frames are handled
	// by the same worker, ensuring order is maintained
	frameQueues := make([]*PriorityFrameQueue, FRAME_WORKERS)
	for i := 0; i < FRAME_WORKERS; i++ {
		frameQueues[i] = NewPriorityFrameQueue(QUEUE_SIZE)

		// Ensure frame queue is drained when connection is closed
		// (frameHandler returns when Pop yields nil; presumably Drain is what
		// makes that happen once the queue empties — confirm in
		// PriorityFrameQueue).
		go func(frameQueue *PriorityFrameQueue) {
			<-s.closeChan
			frameQueue.Drain()
		}(frameQueues[i])

		wg.Add(1)
		go func(frameQueue *PriorityFrameQueue) {
			// let the WaitGroup know this worker is done
			defer wg.Done()

			s.frameHandler(frameQueue, newHandler)
		}(frameQueues[i])
	}

	var (
		partitionRoundRobin int
		goAwayFrame         *spdy.GoAwayFrame
	)
Loop:
	for {
		readFrame, err := s.framer.ReadFrame()
		if err != nil {
			if err != io.EOF {
				debugMessage("frame read error: %s", err)
			} else {
				debugMessage("(%p) EOF received", s)
			}
			// Not inside the switch, so this break leaves the read loop.
			break
		}
		var priority uint8
		var partition int
		switch frame := readFrame.(type) {
		case *spdy.SynStreamFrame:
			// New streams are validated and registered here, on the read
			// goroutine, before the frame is queued. That way
			// getStreamPriority can find them when later frames for the same
			// stream are classified.
			if s.checkStreamFrame(frame) {
				priority = frame.Priority
				partition = int(frame.StreamId % FRAME_WORKERS)
				debugMessage("(%p) Add stream frame: %d ", s, frame.StreamId)
				s.addStreamFrame(frame)
			} else {
				debugMessage("(%p) Rejected stream frame: %d ", s, frame.StreamId)
				continue
			}
		case *spdy.SynReplyFrame:
			priority = s.getStreamPriority(frame.StreamId)
			partition = int(frame.StreamId % FRAME_WORKERS)
		case *spdy.DataFrame:
			priority = s.getStreamPriority(frame.StreamId)
			partition = int(frame.StreamId % FRAME_WORKERS)
		case *spdy.RstStreamFrame:
			priority = s.getStreamPriority(frame.StreamId)
			partition = int(frame.StreamId % FRAME_WORKERS)
		case *spdy.HeadersFrame:
			priority = s.getStreamPriority(frame.StreamId)
			partition = int(frame.StreamId % FRAME_WORKERS)
		case *spdy.PingFrame:
			// Pings are not tied to a stream: priority 0, spread across the
			// workers round-robin.
			priority = 0
			partition = partitionRoundRobin
			partitionRoundRobin = (partitionRoundRobin + 1) % FRAME_WORKERS
		case *spdy.GoAwayFrame:
			// hold on to the go away frame and exit the loop
			goAwayFrame = frame
			break Loop
		default:
			// Other frame types get priority 7 and a round-robin worker;
			// frameHandler will log them as unhandled.
			priority = 7
			partition = partitionRoundRobin
			partitionRoundRobin = (partitionRoundRobin + 1) % FRAME_WORKERS
		}
		frameQueues[partition].Push(readFrame, priority)
	}
	// Closing closeChan starts the Drain goroutines above and also tells the
	// framer's monitor goroutine to shut down.
	close(s.closeChan)

	// wait for all frame handler workers to indicate they've drained their queues
	// before handling the go away frame
	wg.Wait()

	if goAwayFrame != nil {
		s.handleGoAwayFrame(goAwayFrame)
	}

	// now it's safe to close remote channels and empty s.streams
	s.streamCond.L.Lock()
	// notify streams that they're now closed, which will
	// unblock any stream Read() calls
	for _, stream := range s.streams {
		stream.closeRemoteChannels()
	}
	s.streams = make(map[spdy.StreamId]*Stream)
	// Wake shutdown and FindStream waiters so they see the empty table.
	s.streamCond.Broadcast()
	s.streamCond.L.Unlock()
}
   411  
   412  func (s *Connection) frameHandler(frameQueue *PriorityFrameQueue, newHandler StreamHandler) {
   413  	for {
   414  		popFrame := frameQueue.Pop()
   415  		if popFrame == nil {
   416  			return
   417  		}
   418  
   419  		var frameErr error
   420  		switch frame := popFrame.(type) {
   421  		case *spdy.SynStreamFrame:
   422  			frameErr = s.handleStreamFrame(frame, newHandler)
   423  		case *spdy.SynReplyFrame:
   424  			frameErr = s.handleReplyFrame(frame)
   425  		case *spdy.DataFrame:
   426  			frameErr = s.dataFrameHandler(frame)
   427  		case *spdy.RstStreamFrame:
   428  			frameErr = s.handleResetFrame(frame)
   429  		case *spdy.HeadersFrame:
   430  			frameErr = s.handleHeaderFrame(frame)
   431  		case *spdy.PingFrame:
   432  			frameErr = s.handlePingFrame(frame)
   433  		case *spdy.GoAwayFrame:
   434  			frameErr = s.handleGoAwayFrame(frame)
   435  		default:
   436  			frameErr = fmt.Errorf("unhandled frame type: %T", frame)
   437  		}
   438  
   439  		if frameErr != nil {
   440  			debugMessage("frame handling error: %s", frameErr)
   441  		}
   442  	}
   443  }
   444  
   445  func (s *Connection) getStreamPriority(streamId spdy.StreamId) uint8 {
   446  	stream, streamOk := s.getStream(streamId)
   447  	if !streamOk {
   448  		return 7
   449  	}
   450  	return stream.priority
   451  }
   452  
   453  func (s *Connection) addStreamFrame(frame *spdy.SynStreamFrame) {
   454  	var parent *Stream
   455  	if frame.AssociatedToStreamId != spdy.StreamId(0) {
   456  		parent, _ = s.getStream(frame.AssociatedToStreamId)
   457  	}
   458  
   459  	stream := &Stream{
   460  		streamId:   frame.StreamId,
   461  		parent:     parent,
   462  		conn:       s,
   463  		startChan:  make(chan error),
   464  		headers:    frame.Headers,
   465  		finished:   (frame.CFHeader.Flags & spdy.ControlFlagUnidirectional) != 0x00,
   466  		replyCond:  sync.NewCond(new(sync.Mutex)),
   467  		dataChan:   make(chan []byte),
   468  		headerChan: make(chan http.Header),
   469  		closeChan:  make(chan bool),
   470  		priority:   frame.Priority,
   471  	}
   472  	if frame.CFHeader.Flags&spdy.ControlFlagFin != 0x00 {
   473  		stream.closeRemoteChannels()
   474  	}
   475  
   476  	s.addStream(stream)
   477  }
   478  
   479  // checkStreamFrame checks to see if a stream frame is allowed.
   480  // If the stream is invalid, then a reset frame with protocol error
   481  // will be returned.
   482  func (s *Connection) checkStreamFrame(frame *spdy.SynStreamFrame) bool {
   483  	s.receiveIdLock.Lock()
   484  	defer s.receiveIdLock.Unlock()
   485  	if s.goneAway {
   486  		return false
   487  	}
   488  	validationErr := s.validateStreamId(frame.StreamId)
   489  	if validationErr != nil {
   490  		go func() {
   491  			resetErr := s.sendResetFrame(spdy.ProtocolError, frame.StreamId)
   492  			if resetErr != nil {
   493  				debugMessage("reset error: %s", resetErr)
   494  			}
   495  		}()
   496  		return false
   497  	}
   498  	return true
   499  }
   500  
   501  func (s *Connection) handleStreamFrame(frame *spdy.SynStreamFrame, newHandler StreamHandler) error {
   502  	stream, ok := s.getStream(frame.StreamId)
   503  	if !ok {
   504  		return fmt.Errorf("Missing stream: %d", frame.StreamId)
   505  	}
   506  
   507  	newHandler(stream)
   508  
   509  	return nil
   510  }
   511  
   512  func (s *Connection) handleReplyFrame(frame *spdy.SynReplyFrame) error {
   513  	debugMessage("(%p) Reply frame received for %d", s, frame.StreamId)
   514  	stream, streamOk := s.getStream(frame.StreamId)
   515  	if !streamOk {
   516  		debugMessage("Reply frame gone away for %d", frame.StreamId)
   517  		// Stream has already gone away
   518  		return nil
   519  	}
   520  	if stream.replied {
   521  		// Stream has already received reply
   522  		return nil
   523  	}
   524  	stream.replied = true
   525  
   526  	// TODO Check for error
   527  	if (frame.CFHeader.Flags & spdy.ControlFlagFin) != 0x00 {
   528  		s.remoteStreamFinish(stream)
   529  	}
   530  
   531  	close(stream.startChan)
   532  
   533  	return nil
   534  }
   535  
// handleResetFrame processes a RST_STREAM from the peer. The stream is
// removed from the connection, its remote channels are closed, and it is
// marked finished. If no reply had been seen yet, ErrReset is sent on
// startChan and the channel is closed. Unknown streams are ignored.
func (s *Connection) handleResetFrame(frame *spdy.RstStreamFrame) error {
	stream, streamOk := s.getStream(frame.StreamId)
	if !streamOk {
		// Stream has already been removed
		return nil
	}
	s.removeStream(stream)
	stream.closeRemoteChannels()

	if !stream.replied {
		stream.replied = true
		// NOTE(review): startChan is created unbuffered (make(chan error) in
		// CreateStream and addStreamFrame). This send therefore blocks this
		// frame worker until something receives from startChan — presumably
		// Stream.Wait/WaitTimeout. If nothing ever waits, the worker and every
		// stream partitioned to it stall. Confirm against stream.go.
		stream.startChan <- ErrReset
		close(stream.startChan)
	}

	stream.finishLock.Lock()
	stream.finished = true
	stream.finishLock.Unlock()

	return nil
}
   557  
   558  func (s *Connection) handleHeaderFrame(frame *spdy.HeadersFrame) error {
   559  	stream, streamOk := s.getStream(frame.StreamId)
   560  	if !streamOk {
   561  		// Stream has already gone away
   562  		return nil
   563  	}
   564  	if !stream.replied {
   565  		// No reply received...Protocol error?
   566  		return nil
   567  	}
   568  
   569  	// TODO limit headers while not blocking (use buffered chan or goroutine?)
   570  	select {
   571  	case <-stream.closeChan:
   572  		return nil
   573  	case stream.headerChan <- frame.Headers:
   574  	}
   575  
   576  	if (frame.CFHeader.Flags & spdy.ControlFlagFin) != 0x00 {
   577  		s.remoteStreamFinish(stream)
   578  	}
   579  
   580  	return nil
   581  }
   582  
   583  func (s *Connection) handleDataFrame(frame *spdy.DataFrame) error {
   584  	debugMessage("(%p) Data frame received for %d", s, frame.StreamId)
   585  	stream, streamOk := s.getStream(frame.StreamId)
   586  	if !streamOk {
   587  		debugMessage("(%p) Data frame gone away for %d", s, frame.StreamId)
   588  		// Stream has already gone away
   589  		return nil
   590  	}
   591  	if !stream.replied {
   592  		debugMessage("(%p) Data frame not replied %d", s, frame.StreamId)
   593  		// No reply received...Protocol error?
   594  		return nil
   595  	}
   596  
   597  	debugMessage("(%p) (%d) Data frame handling", stream, stream.streamId)
   598  	if len(frame.Data) > 0 {
   599  		stream.dataLock.RLock()
   600  		select {
   601  		case <-stream.closeChan:
   602  			debugMessage("(%p) (%d) Data frame not sent (stream shut down)", stream, stream.streamId)
   603  		case stream.dataChan <- frame.Data:
   604  			debugMessage("(%p) (%d) Data frame sent", stream, stream.streamId)
   605  		}
   606  		stream.dataLock.RUnlock()
   607  	}
   608  	if (frame.Flags & spdy.DataFlagFin) != 0x00 {
   609  		s.remoteStreamFinish(stream)
   610  	}
   611  	return nil
   612  }
   613  
   614  func (s *Connection) handlePingFrame(frame *spdy.PingFrame) error {
   615  	if s.pingId&0x01 != frame.Id&0x01 {
   616  		return s.framer.WriteFrame(frame)
   617  	}
   618  	pingChan, pingOk := s.pingChans[frame.Id]
   619  	if pingOk {
   620  		close(pingChan)
   621  	}
   622  	return nil
   623  }
   624  
   625  func (s *Connection) handleGoAwayFrame(frame *spdy.GoAwayFrame) error {
   626  	debugMessage("(%p) Go away received", s)
   627  	s.receiveIdLock.Lock()
   628  	if s.goneAway {
   629  		s.receiveIdLock.Unlock()
   630  		return nil
   631  	}
   632  	s.goneAway = true
   633  	s.receiveIdLock.Unlock()
   634  
   635  	if s.lastStreamChan != nil {
   636  		stream, _ := s.getStream(frame.LastGoodStreamId)
   637  		go func() {
   638  			s.lastStreamChan <- stream
   639  		}()
   640  	}
   641  
   642  	// Do not block frame handler waiting for closure
   643  	go s.shutdown(s.goAwayTimeout)
   644  
   645  	return nil
   646  }
   647  
   648  func (s *Connection) remoteStreamFinish(stream *Stream) {
   649  	stream.closeRemoteChannels()
   650  
   651  	stream.finishLock.Lock()
   652  	if stream.finished {
   653  		// Stream is fully closed, cleanup
   654  		s.removeStream(stream)
   655  	}
   656  	stream.finishLock.Unlock()
   657  }
   658  
   659  // CreateStream creates a new spdy stream using the parameters for
   660  // creating the stream frame.  The stream frame will be sent upon
   661  // calling this function, however this function does not wait for
   662  // the reply frame.  If waiting for the reply is desired, use
   663  // the stream Wait or WaitTimeout function on the stream returned
   664  // by this function.
   665  func (s *Connection) CreateStream(headers http.Header, parent *Stream, fin bool) (*Stream, error) {
   666  	// MUST synchronize stream creation (all the way to writing the frame)
   667  	// as stream IDs **MUST** increase monotonically.
   668  	s.nextIdLock.Lock()
   669  	defer s.nextIdLock.Unlock()
   670  
   671  	streamId := s.getNextStreamId()
   672  	if streamId == 0 {
   673  		return nil, fmt.Errorf("Unable to get new stream id")
   674  	}
   675  
   676  	stream := &Stream{
   677  		streamId:   streamId,
   678  		parent:     parent,
   679  		conn:       s,
   680  		startChan:  make(chan error),
   681  		headers:    headers,
   682  		dataChan:   make(chan []byte),
   683  		headerChan: make(chan http.Header),
   684  		closeChan:  make(chan bool),
   685  	}
   686  
   687  	debugMessage("(%p) (%p) Create stream", s, stream)
   688  
   689  	s.addStream(stream)
   690  
   691  	return stream, s.sendStream(stream, fin)
   692  }
   693  
// shutdown closes the underlying network connection once every stream has
// been removed, or forcibly when closeTimeout (if > 0) elapses first. Only
// the first call does anything. A non-nil error from closing the connection
// is sent on shutdownChan, for Wait or CloseWait to pick up; shutdownChan is
// closed afterwards in every case.
func (s *Connection) shutdown(closeTimeout time.Duration) {
	// hasShutdown, guarded by shutdownLock, turns repeat calls into no-ops.
	s.shutdownLock.Lock()
	if s.hasShutdown {
		s.shutdownLock.Unlock()
		return
	}
	s.hasShutdown = true
	s.shutdownLock.Unlock()

	// A nil timeout channel never fires, so closeTimeout <= 0 waits for the
	// streams indefinitely.
	var timeout <-chan time.Time
	if closeTimeout > time.Duration(0) {
		timeout = time.After(closeTimeout)
	}
	streamsClosed := make(chan bool)

	// Wait on streamCond until the stream table is empty. addStream,
	// removeStream and Serve's final cleanup broadcast on it.
	go func() {
		s.streamCond.L.Lock()
		for len(s.streams) > 0 {
			debugMessage("Streams opened: %d, %#v", len(s.streams), s.streams)
			s.streamCond.Wait()
		}
		s.streamCond.L.Unlock()
		close(streamsClosed)
	}()

	var err error
	select {
	case <-streamsClosed:
		// No active streams, close should be safe
		err = s.conn.Close()
	case <-timeout:
		// Force ungraceful close
		err = s.conn.Close()
		// Wait for cleanup to clear active streams (presumably closing conn
		// makes Serve's read fail, after which Serve empties s.streams and
		// broadcasts).
		<-streamsClosed
	}

	if err != nil {
		// The send below blocks until Wait or CloseWait receives the error.
		// If nobody has done so after 10 minutes, this timer receives it
		// instead and logs it, which unblocks the send.
		duration := 10 * time.Minute
		time.AfterFunc(duration, func() {
			select {
			case err, ok := <-s.shutdownChan:
				if ok {
					debugMessage("Unhandled close error after %s: %s", duration, err)
				}
			default:
			}
		})
		s.shutdownChan <- err
	}
	close(s.shutdownChan)
}
   747  
   748  // Closes spdy connection by sending GoAway frame and initiating shutdown
   749  func (s *Connection) Close() error {
   750  	s.receiveIdLock.Lock()
   751  	if s.goneAway {
   752  		s.receiveIdLock.Unlock()
   753  		return nil
   754  	}
   755  	s.goneAway = true
   756  	s.receiveIdLock.Unlock()
   757  
   758  	var lastStreamId spdy.StreamId
   759  	if s.receivedStreamId > 2 {
   760  		lastStreamId = s.receivedStreamId - 2
   761  	}
   762  
   763  	goAwayFrame := &spdy.GoAwayFrame{
   764  		LastGoodStreamId: lastStreamId,
   765  		Status:           spdy.GoAwayOK,
   766  	}
   767  
   768  	err := s.framer.WriteFrame(goAwayFrame)
   769  	go s.shutdown(s.closeTimeout)
   770  	if err != nil {
   771  		return err
   772  	}
   773  
   774  	return nil
   775  }
   776  
   777  // CloseWait closes the connection and waits for shutdown
   778  // to finish.  Note the underlying network Connection
   779  // is not closed until the end of shutdown.
   780  func (s *Connection) CloseWait() error {
   781  	closeErr := s.Close()
   782  	if closeErr != nil {
   783  		return closeErr
   784  	}
   785  	shutdownErr, ok := <-s.shutdownChan
   786  	if ok {
   787  		return shutdownErr
   788  	}
   789  	return nil
   790  }
   791  
   792  // Wait waits for the connection to finish shutdown or for
   793  // the wait timeout duration to expire.  This needs to be
   794  // called either after Close has been called or the GOAWAYFRAME
   795  // has been received.  If the wait timeout is 0, this function
   796  // will block until shutdown finishes.  If wait is never called
   797  // and a shutdown error occurs, that error will be logged as an
   798  // unhandled error.
   799  func (s *Connection) Wait(waitTimeout time.Duration) error {
   800  	var timeout <-chan time.Time
   801  	if waitTimeout > time.Duration(0) {
   802  		timeout = time.After(waitTimeout)
   803  	}
   804  
   805  	select {
   806  	case err, ok := <-s.shutdownChan:
   807  		if ok {
   808  			return err
   809  		}
   810  	case <-timeout:
   811  		return ErrTimeout
   812  	}
   813  	return nil
   814  }
   815  
   816  // NotifyClose registers a channel to be called when the remote
   817  // peer inidicates connection closure.  The last stream to be
   818  // received by the remote will be sent on the channel.  The notify
   819  // timeout will determine the duration between go away received
   820  // and the connection being closed.
   821  func (s *Connection) NotifyClose(c chan<- *Stream, timeout time.Duration) {
   822  	s.goAwayTimeout = timeout
   823  	s.lastStreamChan = c
   824  }
   825  
   826  // SetCloseTimeout sets the amount of time close will wait for
   827  // streams to finish before terminating the underlying network
   828  // connection.  Setting the timeout to 0 will cause close to
   829  // wait forever, which is the default.
   830  func (s *Connection) SetCloseTimeout(timeout time.Duration) {
   831  	s.closeTimeout = timeout
   832  }
   833  
   834  // SetIdleTimeout sets the amount of time the connection may sit idle before
   835  // it is forcefully terminated.
   836  func (s *Connection) SetIdleTimeout(timeout time.Duration) {
   837  	s.framer.setIdleTimeout(timeout)
   838  }
   839  
   840  func (s *Connection) sendHeaders(headers http.Header, stream *Stream, fin bool) error {
   841  	var flags spdy.ControlFlags
   842  	if fin {
   843  		flags = spdy.ControlFlagFin
   844  	}
   845  
   846  	headerFrame := &spdy.HeadersFrame{
   847  		StreamId: stream.streamId,
   848  		Headers:  headers,
   849  		CFHeader: spdy.ControlFrameHeader{Flags: flags},
   850  	}
   851  
   852  	return s.framer.WriteFrame(headerFrame)
   853  }
   854  
   855  func (s *Connection) sendReply(headers http.Header, stream *Stream, fin bool) error {
   856  	var flags spdy.ControlFlags
   857  	if fin {
   858  		flags = spdy.ControlFlagFin
   859  	}
   860  
   861  	replyFrame := &spdy.SynReplyFrame{
   862  		StreamId: stream.streamId,
   863  		Headers:  headers,
   864  		CFHeader: spdy.ControlFrameHeader{Flags: flags},
   865  	}
   866  
   867  	return s.framer.WriteFrame(replyFrame)
   868  }
   869  
   870  func (s *Connection) sendResetFrame(status spdy.RstStreamStatus, streamId spdy.StreamId) error {
   871  	resetFrame := &spdy.RstStreamFrame{
   872  		StreamId: streamId,
   873  		Status:   status,
   874  	}
   875  
   876  	return s.framer.WriteFrame(resetFrame)
   877  }
   878  
   879  func (s *Connection) sendReset(status spdy.RstStreamStatus, stream *Stream) error {
   880  	return s.sendResetFrame(status, stream.streamId)
   881  }
   882  
   883  func (s *Connection) sendStream(stream *Stream, fin bool) error {
   884  	var flags spdy.ControlFlags
   885  	if fin {
   886  		flags = spdy.ControlFlagFin
   887  		stream.finished = true
   888  	}
   889  
   890  	var parentId spdy.StreamId
   891  	if stream.parent != nil {
   892  		parentId = stream.parent.streamId
   893  	}
   894  
   895  	streamFrame := &spdy.SynStreamFrame{
   896  		StreamId:             spdy.StreamId(stream.streamId),
   897  		AssociatedToStreamId: spdy.StreamId(parentId),
   898  		Headers:              stream.headers,
   899  		CFHeader:             spdy.ControlFrameHeader{Flags: flags},
   900  	}
   901  
   902  	return s.framer.WriteFrame(streamFrame)
   903  }
   904  
   905  // getNextStreamId returns the next sequential id
   906  // every call should produce a unique value or an error
   907  func (s *Connection) getNextStreamId() spdy.StreamId {
   908  	sid := s.nextStreamId
   909  	if sid > 0x7fffffff {
   910  		return 0
   911  	}
   912  	s.nextStreamId = s.nextStreamId + 2
   913  	return sid
   914  }
   915  
   916  // PeekNextStreamId returns the next sequential id and keeps the next id untouched
   917  func (s *Connection) PeekNextStreamId() spdy.StreamId {
   918  	sid := s.nextStreamId
   919  	return sid
   920  }
   921  
   922  func (s *Connection) validateStreamId(rid spdy.StreamId) error {
   923  	if rid > 0x7fffffff || rid < s.receivedStreamId {
   924  		return ErrInvalidStreamId
   925  	}
   926  	s.receivedStreamId = rid + 2
   927  	return nil
   928  }
   929  
   930  func (s *Connection) addStream(stream *Stream) {
   931  	s.streamCond.L.Lock()
   932  	s.streams[stream.streamId] = stream
   933  	debugMessage("(%p) (%p) Stream added, broadcasting: %d", s, stream, stream.streamId)
   934  	s.streamCond.Broadcast()
   935  	s.streamCond.L.Unlock()
   936  }
   937  
   938  func (s *Connection) removeStream(stream *Stream) {
   939  	s.streamCond.L.Lock()
   940  	delete(s.streams, stream.streamId)
   941  	debugMessage("(%p) (%p) Stream removed, broadcasting: %d", s, stream, stream.streamId)
   942  	s.streamCond.Broadcast()
   943  	s.streamCond.L.Unlock()
   944  }
   945  
   946  func (s *Connection) getStream(streamId spdy.StreamId) (stream *Stream, ok bool) {
   947  	s.streamLock.RLock()
   948  	stream, ok = s.streams[streamId]
   949  	s.streamLock.RUnlock()
   950  	return
   951  }
   952  
   953  // FindStream looks up the given stream id and either waits for the
   954  // stream to be found or returns nil if the stream id is no longer
   955  // valid.
   956  func (s *Connection) FindStream(streamId uint32) *Stream {
   957  	var stream *Stream
   958  	var ok bool
   959  	s.streamCond.L.Lock()
   960  	stream, ok = s.streams[spdy.StreamId(streamId)]
   961  	debugMessage("(%p) Found stream %d? %t", s, spdy.StreamId(streamId), ok)
   962  	for !ok && streamId >= uint32(s.receivedStreamId) {
   963  		s.streamCond.Wait()
   964  		stream, ok = s.streams[spdy.StreamId(streamId)]
   965  	}
   966  	s.streamCond.L.Unlock()
   967  	return stream
   968  }
   969  
   970  func (s *Connection) CloseChan() <-chan bool {
   971  	return s.closeChan
   972  }
   973  

View as plain text