...

Source file src/github.com/containerd/ttrpc/client.go

Documentation: github.com/containerd/ttrpc

     1  /*
     2     Copyright The containerd Authors.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  package ttrpc
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"net"
    25  	"strings"
    26  	"sync"
    27  	"syscall"
    28  	"time"
    29  
    30  	"github.com/sirupsen/logrus"
    31  	"google.golang.org/grpc/codes"
    32  	"google.golang.org/grpc/status"
    33  	"google.golang.org/protobuf/proto"
    34  )
    35  
    36  // Client for a ttrpc server
    37  type Client struct {
    38  	codec   codec
    39  	conn    net.Conn
    40  	channel *channel
    41  
    42  	streamLock   sync.RWMutex
    43  	streams      map[streamID]*stream
    44  	nextStreamID streamID
    45  	sendLock     sync.Mutex
    46  
    47  	ctx    context.Context
    48  	closed func()
    49  
    50  	closeOnce       sync.Once
    51  	userCloseFunc   func()
    52  	userCloseWaitCh chan struct{}
    53  
    54  	interceptor UnaryClientInterceptor
    55  }
    56  
    57  // ClientOpts configures a client
    58  type ClientOpts func(c *Client)
    59  
    60  // WithOnClose sets the close func whenever the client's Close() method is called
    61  func WithOnClose(onClose func()) ClientOpts {
    62  	return func(c *Client) {
    63  		c.userCloseFunc = onClose
    64  	}
    65  }
    66  
    67  // WithUnaryClientInterceptor sets the provided client interceptor
    68  func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
    69  	return func(c *Client) {
    70  		c.interceptor = i
    71  	}
    72  }
    73  
    74  // NewClient creates a new ttrpc client using the given connection
    75  func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
    76  	ctx, cancel := context.WithCancel(context.Background())
    77  	channel := newChannel(conn)
    78  	c := &Client{
    79  		codec:           codec{},
    80  		conn:            conn,
    81  		channel:         channel,
    82  		streams:         make(map[streamID]*stream),
    83  		nextStreamID:    1,
    84  		closed:          cancel,
    85  		ctx:             ctx,
    86  		userCloseFunc:   func() {},
    87  		userCloseWaitCh: make(chan struct{}),
    88  		interceptor:     defaultClientInterceptor,
    89  	}
    90  
    91  	for _, o := range opts {
    92  		o(c)
    93  	}
    94  
    95  	go c.run()
    96  	return c
    97  }
    98  
    99  func (c *Client) send(sid uint32, mt messageType, flags uint8, b []byte) error {
   100  	c.sendLock.Lock()
   101  	defer c.sendLock.Unlock()
   102  	return c.channel.send(sid, mt, flags, b)
   103  }
   104  
   105  // Call makes a unary request and returns with response
   106  func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
   107  	payload, err := c.codec.Marshal(req)
   108  	if err != nil {
   109  		return err
   110  	}
   111  
   112  	var (
   113  		creq = &Request{
   114  			Service: service,
   115  			Method:  method,
   116  			Payload: payload,
   117  			// TODO: metadata from context
   118  		}
   119  
   120  		cresp = &Response{}
   121  	)
   122  
   123  	if metadata, ok := GetMetadata(ctx); ok {
   124  		metadata.setRequest(creq)
   125  	}
   126  
   127  	if dl, ok := ctx.Deadline(); ok {
   128  		creq.TimeoutNano = time.Until(dl).Nanoseconds()
   129  	}
   130  
   131  	info := &UnaryClientInfo{
   132  		FullMethod: fullPath(service, method),
   133  	}
   134  	if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
   135  		return err
   136  	}
   137  
   138  	if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
   139  		return err
   140  	}
   141  
   142  	if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
   143  		return status.ErrorProto(cresp.Status)
   144  	}
   145  	return nil
   146  }
   147  
   148  // StreamDesc describes the stream properties, whether the stream has
   149  // a streaming client, a streaming server, or both
   150  type StreamDesc struct {
   151  	StreamingClient bool
   152  	StreamingServer bool
   153  }
   154  
   155  // ClientStream is used to send or recv messages on the underlying stream
   156  type ClientStream interface {
   157  	CloseSend() error
   158  	SendMsg(m interface{}) error
   159  	RecvMsg(m interface{}) error
   160  }
   161  
   162  type clientStream struct {
   163  	ctx          context.Context
   164  	s            *stream
   165  	c            *Client
   166  	desc         *StreamDesc
   167  	localClosed  bool
   168  	remoteClosed bool
   169  }
   170  
   171  func (cs *clientStream) CloseSend() error {
   172  	if !cs.desc.StreamingClient {
   173  		return fmt.Errorf("%w: cannot close non-streaming client", ErrProtocol)
   174  	}
   175  	if cs.localClosed {
   176  		return ErrStreamClosed
   177  	}
   178  	err := cs.s.send(messageTypeData, flagRemoteClosed|flagNoData, nil)
   179  	if err != nil {
   180  		return filterCloseErr(err)
   181  	}
   182  	cs.localClosed = true
   183  	return nil
   184  }
   185  
   186  func (cs *clientStream) SendMsg(m interface{}) error {
   187  	if !cs.desc.StreamingClient {
   188  		return fmt.Errorf("%w: cannot send data from non-streaming client", ErrProtocol)
   189  	}
   190  	if cs.localClosed {
   191  		return ErrStreamClosed
   192  	}
   193  
   194  	var (
   195  		payload []byte
   196  		err     error
   197  	)
   198  	if m != nil {
   199  		payload, err = cs.c.codec.Marshal(m)
   200  		if err != nil {
   201  			return err
   202  		}
   203  	}
   204  
   205  	err = cs.s.send(messageTypeData, 0, payload)
   206  	if err != nil {
   207  		return filterCloseErr(err)
   208  	}
   209  
   210  	return nil
   211  }
   212  
   213  func (cs *clientStream) RecvMsg(m interface{}) error {
   214  	if cs.remoteClosed {
   215  		return io.EOF
   216  	}
   217  
   218  	var msg *streamMessage
   219  	select {
   220  	case <-cs.ctx.Done():
   221  		return cs.ctx.Err()
   222  	case <-cs.s.recvClose:
   223  		// If recv has a pending message, process that first
   224  		select {
   225  		case msg = <-cs.s.recv:
   226  		default:
   227  			return cs.s.recvErr
   228  		}
   229  	case msg = <-cs.s.recv:
   230  	}
   231  
   232  	if msg.header.Type == messageTypeResponse {
   233  		resp := &Response{}
   234  		err := proto.Unmarshal(msg.payload[:msg.header.Length], resp)
   235  		// return the payload buffer for reuse
   236  		cs.c.channel.putmbuf(msg.payload)
   237  		if err != nil {
   238  			return err
   239  		}
   240  
   241  		if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil {
   242  			return err
   243  		}
   244  
   245  		if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
   246  			return status.ErrorProto(resp.Status)
   247  		}
   248  
   249  		cs.c.deleteStream(cs.s)
   250  		cs.remoteClosed = true
   251  
   252  		return nil
   253  	} else if msg.header.Type == messageTypeData {
   254  		if !cs.desc.StreamingServer {
   255  			cs.c.deleteStream(cs.s)
   256  			cs.remoteClosed = true
   257  			return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol)
   258  		}
   259  		if msg.header.Flags&flagRemoteClosed == flagRemoteClosed {
   260  			cs.c.deleteStream(cs.s)
   261  			cs.remoteClosed = true
   262  
   263  			if msg.header.Flags&flagNoData == flagNoData {
   264  				return io.EOF
   265  			}
   266  		}
   267  
   268  		err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m)
   269  		cs.c.channel.putmbuf(msg.payload)
   270  		if err != nil {
   271  			return err
   272  		}
   273  		return nil
   274  	}
   275  
   276  	return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
   277  }
   278  
   279  // Close closes the ttrpc connection and underlying connection
   280  func (c *Client) Close() error {
   281  	c.closeOnce.Do(func() {
   282  		c.closed()
   283  
   284  		c.conn.Close()
   285  	})
   286  	return nil
   287  }
   288  
   289  // UserOnCloseWait is used to blocks untils the user's on-close callback
   290  // finishes.
   291  func (c *Client) UserOnCloseWait(ctx context.Context) error {
   292  	select {
   293  	case <-c.userCloseWaitCh:
   294  		return nil
   295  	case <-ctx.Done():
   296  		return ctx.Err()
   297  	}
   298  }
   299  
   300  func (c *Client) run() {
   301  	err := c.receiveLoop()
   302  	c.Close()
   303  	c.cleanupStreams(err)
   304  
   305  	c.userCloseFunc()
   306  	close(c.userCloseWaitCh)
   307  }
   308  
   309  func (c *Client) receiveLoop() error {
   310  	for {
   311  		select {
   312  		case <-c.ctx.Done():
   313  			return ErrClosed
   314  		default:
   315  			var (
   316  				msg = &streamMessage{}
   317  				err error
   318  			)
   319  
   320  			msg.header, msg.payload, err = c.channel.recv()
   321  			if err != nil {
   322  				_, ok := status.FromError(err)
   323  				if !ok {
   324  					// treat all errors that are not an rpc status as terminal.
   325  					// all others poison the connection.
   326  					return filterCloseErr(err)
   327  				}
   328  			}
   329  			sid := streamID(msg.header.StreamID)
   330  			s := c.getStream(sid)
   331  			if s == nil {
   332  				logrus.WithField("stream", sid).Errorf("ttrpc: received message on inactive stream")
   333  				continue
   334  			}
   335  
   336  			if err != nil {
   337  				s.closeWithError(err)
   338  			} else {
   339  				if err := s.receive(c.ctx, msg); err != nil {
   340  					logrus.WithError(err).WithField("stream", sid).Errorf("ttrpc: failed to handle message")
   341  				}
   342  			}
   343  		}
   344  	}
   345  }
   346  
   347  // createStream creates a new stream and registers it with the client
   348  // Introduce stream types for multiple or single response
   349  func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
   350  	c.streamLock.Lock()
   351  
   352  	// Check if closed since lock acquired to prevent adding
   353  	// anything after cleanup completes
   354  	select {
   355  	case <-c.ctx.Done():
   356  		c.streamLock.Unlock()
   357  		return nil, ErrClosed
   358  	default:
   359  	}
   360  
   361  	// Stream ID should be allocated at same time
   362  	s := newStream(c.nextStreamID, c)
   363  	c.streams[s.id] = s
   364  	c.nextStreamID = c.nextStreamID + 2
   365  
   366  	c.sendLock.Lock()
   367  	defer c.sendLock.Unlock()
   368  	c.streamLock.Unlock()
   369  
   370  	if err := c.channel.send(uint32(s.id), messageTypeRequest, flags, b); err != nil {
   371  		return s, filterCloseErr(err)
   372  	}
   373  
   374  	return s, nil
   375  }
   376  
   377  func (c *Client) deleteStream(s *stream) {
   378  	c.streamLock.Lock()
   379  	delete(c.streams, s.id)
   380  	c.streamLock.Unlock()
   381  	s.closeWithError(nil)
   382  }
   383  
   384  func (c *Client) getStream(sid streamID) *stream {
   385  	c.streamLock.RLock()
   386  	s := c.streams[sid]
   387  	c.streamLock.RUnlock()
   388  	return s
   389  }
   390  
   391  func (c *Client) cleanupStreams(err error) {
   392  	c.streamLock.Lock()
   393  	defer c.streamLock.Unlock()
   394  
   395  	for sid, s := range c.streams {
   396  		s.closeWithError(err)
   397  		delete(c.streams, sid)
   398  	}
   399  }
   400  
   401  // filterCloseErr rewrites EOF and EPIPE errors to ErrClosed. Use when
   402  // returning from call or handling errors from main read loop.
   403  //
   404  // This purposely ignores errors with a wrapped cause.
   405  func filterCloseErr(err error) error {
   406  	switch {
   407  	case err == nil:
   408  		return nil
   409  	case err == io.EOF:
   410  		return ErrClosed
   411  	case errors.Is(err, io.ErrClosedPipe):
   412  		return ErrClosed
   413  	case errors.Is(err, io.EOF):
   414  		return ErrClosed
   415  	case strings.Contains(err.Error(), "use of closed network connection"):
   416  		return ErrClosed
   417  	default:
   418  		// if we have an epipe on a write or econnreset on a read , we cast to errclosed
   419  		var oerr *net.OpError
   420  		if errors.As(err, &oerr) {
   421  			if (oerr.Op == "write" && errors.Is(err, syscall.EPIPE)) ||
   422  				(oerr.Op == "read" && errors.Is(err, syscall.ECONNRESET)) {
   423  				return ErrClosed
   424  			}
   425  		}
   426  	}
   427  
   428  	return err
   429  }
   430  
   431  // NewStream creates a new stream with the given stream descriptor to the
   432  // specified service and method. If not a streaming client, the request object
   433  // may be provided.
   434  func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, method string, req interface{}) (ClientStream, error) {
   435  	var payload []byte
   436  	if req != nil {
   437  		var err error
   438  		payload, err = c.codec.Marshal(req)
   439  		if err != nil {
   440  			return nil, err
   441  		}
   442  	}
   443  
   444  	request := &Request{
   445  		Service: service,
   446  		Method:  method,
   447  		Payload: payload,
   448  		// TODO: metadata from context
   449  	}
   450  	p, err := c.codec.Marshal(request)
   451  	if err != nil {
   452  		return nil, err
   453  	}
   454  
   455  	var flags uint8
   456  	if desc.StreamingClient {
   457  		flags = flagRemoteOpen
   458  	} else {
   459  		flags = flagRemoteClosed
   460  	}
   461  	s, err := c.createStream(flags, p)
   462  	if err != nil {
   463  		return nil, err
   464  	}
   465  
   466  	return &clientStream{
   467  		ctx:  ctx,
   468  		s:    s,
   469  		c:    c,
   470  		desc: desc,
   471  	}, nil
   472  }
   473  
   474  func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
   475  	p, err := c.codec.Marshal(req)
   476  	if err != nil {
   477  		return err
   478  	}
   479  
   480  	s, err := c.createStream(0, p)
   481  	if err != nil {
   482  		return err
   483  	}
   484  	defer c.deleteStream(s)
   485  
   486  	var msg *streamMessage
   487  	select {
   488  	case <-ctx.Done():
   489  		return ctx.Err()
   490  	case <-c.ctx.Done():
   491  		return ErrClosed
   492  	case <-s.recvClose:
   493  		// If recv has a pending message, process that first
   494  		select {
   495  		case msg = <-s.recv:
   496  		default:
   497  			return s.recvErr
   498  		}
   499  	case msg = <-s.recv:
   500  	}
   501  
   502  	if msg.header.Type == messageTypeResponse {
   503  		err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
   504  	} else {
   505  		err = fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
   506  	}
   507  
   508  	// return the payload buffer for reuse
   509  	c.channel.putmbuf(msg.payload)
   510  
   511  	return err
   512  }
   513  

View as plain text