...

Source file src/github.com/containerd/ttrpc/server.go

Documentation: github.com/containerd/ttrpc

     1  /*
     2     Copyright The containerd Authors.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  package ttrpc
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"io"
    23  	"math/rand"
    24  	"net"
    25  	"sync"
    26  	"sync/atomic"
    27  	"syscall"
    28  	"time"
    29  
    30  	"github.com/sirupsen/logrus"
    31  	"google.golang.org/grpc/codes"
    32  	"google.golang.org/grpc/status"
    33  )
    34  
    35  type Server struct {
    36  	config   *serverConfig
    37  	services *serviceSet
    38  	codec    codec
    39  
    40  	mu          sync.Mutex
    41  	listeners   map[net.Listener]struct{}
    42  	connections map[*serverConn]struct{} // all connections to current state
    43  	done        chan struct{}            // marks point at which we stop serving requests
    44  }
    45  
    46  func NewServer(opts ...ServerOpt) (*Server, error) {
    47  	config := &serverConfig{}
    48  	for _, opt := range opts {
    49  		if err := opt(config); err != nil {
    50  			return nil, err
    51  		}
    52  	}
    53  	if config.interceptor == nil {
    54  		config.interceptor = defaultServerInterceptor
    55  	}
    56  
    57  	return &Server{
    58  		config:      config,
    59  		services:    newServiceSet(config.interceptor),
    60  		done:        make(chan struct{}),
    61  		listeners:   make(map[net.Listener]struct{}),
    62  		connections: make(map[*serverConn]struct{}),
    63  	}, nil
    64  }
    65  
    66  // Register registers a map of methods to method handlers
    67  // TODO: Remove in 2.0, does not support streams
    68  func (s *Server) Register(name string, methods map[string]Method) {
    69  	s.services.register(name, &ServiceDesc{Methods: methods})
    70  }
    71  
    72  func (s *Server) RegisterService(name string, desc *ServiceDesc) {
    73  	s.services.register(name, desc)
    74  }
    75  
    76  func (s *Server) Serve(ctx context.Context, l net.Listener) error {
    77  	s.addListener(l)
    78  	defer s.closeListener(l)
    79  
    80  	var (
    81  		backoff    time.Duration
    82  		handshaker = s.config.handshaker
    83  	)
    84  
    85  	if handshaker == nil {
    86  		handshaker = handshakerFunc(noopHandshake)
    87  	}
    88  
    89  	for {
    90  		conn, err := l.Accept()
    91  		if err != nil {
    92  			select {
    93  			case <-s.done:
    94  				return ErrServerClosed
    95  			default:
    96  			}
    97  
    98  			if terr, ok := err.(interface {
    99  				Temporary() bool
   100  			}); ok && terr.Temporary() {
   101  				if backoff == 0 {
   102  					backoff = time.Millisecond
   103  				} else {
   104  					backoff *= 2
   105  				}
   106  
   107  				if max := time.Second; backoff > max {
   108  					backoff = max
   109  				}
   110  
   111  				sleep := time.Duration(rand.Int63n(int64(backoff)))
   112  				logrus.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
   113  				time.Sleep(sleep)
   114  				continue
   115  			}
   116  
   117  			return err
   118  		}
   119  
   120  		backoff = 0
   121  
   122  		approved, handshake, err := handshaker.Handshake(ctx, conn)
   123  		if err != nil {
   124  			logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
   125  			conn.Close()
   126  			continue
   127  		}
   128  
   129  		sc, err := s.newConn(approved, handshake)
   130  		if err != nil {
   131  			logrus.WithError(err).Error("ttrpc: create connection failed")
   132  			conn.Close()
   133  			continue
   134  		}
   135  
   136  		go sc.run(ctx)
   137  	}
   138  }
   139  
   140  func (s *Server) Shutdown(ctx context.Context) error {
   141  	s.mu.Lock()
   142  	select {
   143  	case <-s.done:
   144  	default:
   145  		// protected by mutex
   146  		close(s.done)
   147  	}
   148  	lnerr := s.closeListeners()
   149  	s.mu.Unlock()
   150  
   151  	ticker := time.NewTicker(200 * time.Millisecond)
   152  	defer ticker.Stop()
   153  	for {
   154  		s.closeIdleConns()
   155  
   156  		if s.countConnection() == 0 {
   157  			break
   158  		}
   159  
   160  		select {
   161  		case <-ctx.Done():
   162  			return ctx.Err()
   163  		case <-ticker.C:
   164  		}
   165  	}
   166  
   167  	return lnerr
   168  }
   169  
   170  // Close the server without waiting for active connections.
   171  func (s *Server) Close() error {
   172  	s.mu.Lock()
   173  	defer s.mu.Unlock()
   174  
   175  	select {
   176  	case <-s.done:
   177  	default:
   178  		// protected by mutex
   179  		close(s.done)
   180  	}
   181  
   182  	err := s.closeListeners()
   183  	for c := range s.connections {
   184  		c.close()
   185  		delete(s.connections, c)
   186  	}
   187  
   188  	return err
   189  }
   190  
   191  func (s *Server) addListener(l net.Listener) {
   192  	s.mu.Lock()
   193  	defer s.mu.Unlock()
   194  	s.listeners[l] = struct{}{}
   195  }
   196  
   197  func (s *Server) closeListener(l net.Listener) error {
   198  	s.mu.Lock()
   199  	defer s.mu.Unlock()
   200  
   201  	return s.closeListenerLocked(l)
   202  }
   203  
   204  func (s *Server) closeListenerLocked(l net.Listener) error {
   205  	defer delete(s.listeners, l)
   206  	return l.Close()
   207  }
   208  
   209  func (s *Server) closeListeners() error {
   210  	var err error
   211  	for l := range s.listeners {
   212  		if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
   213  			err = cerr
   214  		}
   215  	}
   216  	return err
   217  }
   218  
   219  func (s *Server) addConnection(c *serverConn) error {
   220  	s.mu.Lock()
   221  	defer s.mu.Unlock()
   222  
   223  	select {
   224  	case <-s.done:
   225  		return ErrServerClosed
   226  	default:
   227  	}
   228  
   229  	s.connections[c] = struct{}{}
   230  	return nil
   231  }
   232  
   233  func (s *Server) delConnection(c *serverConn) {
   234  	s.mu.Lock()
   235  	defer s.mu.Unlock()
   236  
   237  	delete(s.connections, c)
   238  }
   239  
   240  func (s *Server) countConnection() int {
   241  	s.mu.Lock()
   242  	defer s.mu.Unlock()
   243  
   244  	return len(s.connections)
   245  }
   246  
   247  func (s *Server) closeIdleConns() {
   248  	s.mu.Lock()
   249  	defer s.mu.Unlock()
   250  
   251  	for c := range s.connections {
   252  		if st, ok := c.getState(); !ok || st == connStateActive {
   253  			continue
   254  		}
   255  		c.close()
   256  		delete(s.connections, c)
   257  	}
   258  }
   259  
   260  type connState int
   261  
   262  const (
   263  	connStateActive = iota + 1 // outstanding requests
   264  	connStateIdle              // no requests
   265  	connStateClosed            // closed connection
   266  )
   267  
   268  func (cs connState) String() string {
   269  	switch cs {
   270  	case connStateActive:
   271  		return "active"
   272  	case connStateIdle:
   273  		return "idle"
   274  	case connStateClosed:
   275  		return "closed"
   276  	default:
   277  		return "unknown"
   278  	}
   279  }
   280  
   281  func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
   282  	c := &serverConn{
   283  		server:    s,
   284  		conn:      conn,
   285  		handshake: handshake,
   286  		shutdown:  make(chan struct{}),
   287  	}
   288  	c.setState(connStateIdle)
   289  	if err := s.addConnection(c); err != nil {
   290  		c.close()
   291  		return nil, err
   292  	}
   293  	return c, nil
   294  }
   295  
   296  type serverConn struct {
   297  	server    *Server
   298  	conn      net.Conn
   299  	handshake interface{} // data from handshake, not used for now
   300  	state     atomic.Value
   301  
   302  	shutdownOnce sync.Once
   303  	shutdown     chan struct{} // forced shutdown, used by close
   304  }
   305  
   306  func (c *serverConn) getState() (connState, bool) {
   307  	cs, ok := c.state.Load().(connState)
   308  	return cs, ok
   309  }
   310  
   311  func (c *serverConn) setState(newstate connState) {
   312  	c.state.Store(newstate)
   313  }
   314  
   315  func (c *serverConn) close() error {
   316  	c.shutdownOnce.Do(func() {
   317  		close(c.shutdown)
   318  	})
   319  
   320  	return nil
   321  }
   322  
   323  func (c *serverConn) run(sctx context.Context) {
   324  	type (
   325  		response struct {
   326  			id          uint32
   327  			status      *status.Status
   328  			data        []byte
   329  			closeStream bool
   330  			streaming   bool
   331  		}
   332  	)
   333  
   334  	var (
   335  		ch                     = newChannel(c.conn)
   336  		ctx, cancel            = context.WithCancel(sctx)
   337  		state        connState = connStateIdle
   338  		responses              = make(chan response)
   339  		recvErr                = make(chan error, 1)
   340  		done                   = make(chan struct{})
   341  		streams                = sync.Map{}
   342  		active       int32
   343  		lastStreamID uint32
   344  	)
   345  
   346  	defer c.conn.Close()
   347  	defer cancel()
   348  	defer close(done)
   349  	defer c.server.delConnection(c)
   350  
   351  	sendStatus := func(id uint32, st *status.Status) bool {
   352  		select {
   353  		case responses <- response{
   354  			// even though we've had an invalid stream id, we send it
   355  			// back on the same stream id so the client knows which
   356  			// stream id was bad.
   357  			id:          id,
   358  			status:      st,
   359  			closeStream: true,
   360  		}:
   361  			return true
   362  		case <-c.shutdown:
   363  			return false
   364  		case <-done:
   365  			return false
   366  		}
   367  	}
   368  
   369  	go func(recvErr chan error) {
   370  		defer close(recvErr)
   371  		for {
   372  			select {
   373  			case <-c.shutdown:
   374  				return
   375  			case <-done:
   376  				return
   377  			default: // proceed
   378  			}
   379  
   380  			mh, p, err := ch.recv()
   381  			if err != nil {
   382  				status, ok := status.FromError(err)
   383  				if !ok {
   384  					recvErr <- err
   385  					return
   386  				}
   387  
   388  				// in this case, we send an error for that particular message
   389  				// when the status is defined.
   390  				if !sendStatus(mh.StreamID, status) {
   391  					return
   392  				}
   393  
   394  				continue
   395  			}
   396  
   397  			if mh.StreamID%2 != 1 {
   398  				// enforce odd client initiated identifiers.
   399  				if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
   400  					return
   401  				}
   402  				continue
   403  			}
   404  
   405  			if mh.Type == messageTypeData {
   406  				i, ok := streams.Load(mh.StreamID)
   407  				if !ok {
   408  					if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID is no longer active")) {
   409  						return
   410  					}
   411  				}
   412  				sh := i.(*streamHandler)
   413  				if mh.Flags&flagNoData != flagNoData {
   414  					unmarshal := func(obj interface{}) error {
   415  						err := protoUnmarshal(p, obj)
   416  						ch.putmbuf(p)
   417  						return err
   418  					}
   419  
   420  					if err := sh.data(unmarshal); err != nil {
   421  						if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data handling error: %v", err)) {
   422  							return
   423  						}
   424  					}
   425  				}
   426  
   427  				if mh.Flags&flagRemoteClosed == flagRemoteClosed {
   428  					sh.closeSend()
   429  					if len(p) > 0 {
   430  						if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data close message cannot include data")) {
   431  							return
   432  						}
   433  					}
   434  				}
   435  			} else if mh.Type == messageTypeRequest {
   436  				if mh.StreamID <= lastStreamID {
   437  					// enforce odd client initiated identifiers.
   438  					if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID cannot be re-used and must increment")) {
   439  						return
   440  					}
   441  					continue
   442  
   443  				}
   444  				lastStreamID = mh.StreamID
   445  
   446  				// TODO: Make request type configurable
   447  				// Unmarshaller which takes in a byte array and returns an interface?
   448  				var req Request
   449  				if err := c.server.codec.Unmarshal(p, &req); err != nil {
   450  					ch.putmbuf(p)
   451  					if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
   452  						return
   453  					}
   454  					continue
   455  				}
   456  				ch.putmbuf(p)
   457  
   458  				id := mh.StreamID
   459  				respond := func(status *status.Status, data []byte, streaming, closeStream bool) error {
   460  					select {
   461  					case responses <- response{
   462  						id:          id,
   463  						status:      status,
   464  						data:        data,
   465  						closeStream: closeStream,
   466  						streaming:   streaming,
   467  					}:
   468  					case <-done:
   469  						return ErrClosed
   470  					}
   471  					return nil
   472  				}
   473  				sh, err := c.server.services.handle(ctx, &req, respond)
   474  				if err != nil {
   475  					status, _ := status.FromError(err)
   476  					if !sendStatus(mh.StreamID, status) {
   477  						return
   478  					}
   479  					continue
   480  				}
   481  
   482  				streams.Store(id, sh)
   483  				atomic.AddInt32(&active, 1)
   484  			}
   485  			// TODO: else we must ignore this for future compat. log this?
   486  		}
   487  	}(recvErr)
   488  
   489  	for {
   490  		var (
   491  			newstate connState
   492  			shutdown chan struct{}
   493  		)
   494  
   495  		activeN := atomic.LoadInt32(&active)
   496  		if activeN > 0 {
   497  			newstate = connStateActive
   498  			shutdown = nil
   499  		} else {
   500  			newstate = connStateIdle
   501  			shutdown = c.shutdown // only enable this branch in idle mode
   502  		}
   503  		if newstate != state {
   504  			c.setState(newstate)
   505  			state = newstate
   506  		}
   507  
   508  		select {
   509  		case response := <-responses:
   510  			if !response.streaming || response.status.Code() != codes.OK {
   511  				p, err := c.server.codec.Marshal(&Response{
   512  					Status:  response.status.Proto(),
   513  					Payload: response.data,
   514  				})
   515  				if err != nil {
   516  					logrus.WithError(err).Error("failed marshaling response")
   517  					return
   518  				}
   519  
   520  				if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
   521  					logrus.WithError(err).Error("failed sending message on channel")
   522  					return
   523  				}
   524  			} else {
   525  				var flags uint8
   526  				if response.closeStream {
   527  					flags = flagRemoteClosed
   528  				}
   529  				if response.data == nil {
   530  					flags = flags | flagNoData
   531  				}
   532  				if err := ch.send(response.id, messageTypeData, flags, response.data); err != nil {
   533  					logrus.WithError(err).Error("failed sending message on channel")
   534  					return
   535  				}
   536  			}
   537  
   538  			if response.closeStream {
   539  				// The ttrpc protocol currently does not support the case where
   540  				// the server is localClosed but not remoteClosed. Once the server
   541  				// is closing, the whole stream may be considered finished
   542  				streams.Delete(response.id)
   543  				atomic.AddInt32(&active, -1)
   544  			}
   545  		case err := <-recvErr:
   546  			// TODO(stevvooe): Not wildly clear what we should do in this
   547  			// branch. Basically, it means that we are no longer receiving
   548  			// requests due to a terminal error.
   549  			recvErr = nil // connection is now "closing"
   550  			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, syscall.ECONNRESET) {
   551  				// The client went away and we should stop processing
   552  				// requests, so that the client connection is closed
   553  				return
   554  			}
   555  			logrus.WithError(err).Error("error receiving message")
   556  			// else, initiate shutdown
   557  		case <-shutdown:
   558  			return
   559  		}
   560  	}
   561  }
   562  
   563  var noopFunc = func() {}
   564  
   565  func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
   566  	if len(req.Metadata) > 0 {
   567  		md := MD{}
   568  		md.fromRequest(req)
   569  		ctx = WithMetadata(ctx, md)
   570  	}
   571  
   572  	cancel = noopFunc
   573  	if req.TimeoutNano == 0 {
   574  		return ctx, cancel
   575  	}
   576  
   577  	ctx, cancel = context.WithTimeout(ctx, time.Duration(req.TimeoutNano))
   578  	return ctx, cancel
   579  }
   580  

View as plain text