...

Source file src/go.etcd.io/etcd/server/v3/etcdserver/api/v3rpc/watch.go

Documentation: go.etcd.io/etcd/server/v3/etcdserver/api/v3rpc

     1  // Copyright 2015 The etcd Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package v3rpc
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"math/rand"
    22  	"sync"
    23  	"time"
    24  
    25  	pb "go.etcd.io/etcd/api/v3/etcdserverpb"
    26  	"go.etcd.io/etcd/api/v3/mvccpb"
    27  	"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
    28  	clientv3 "go.etcd.io/etcd/client/v3"
    29  	"go.etcd.io/etcd/server/v3/auth"
    30  	"go.etcd.io/etcd/server/v3/etcdserver"
    31  	"go.etcd.io/etcd/server/v3/mvcc"
    32  
    33  	"go.uber.org/zap"
    34  )
    35  
    36  const minWatchProgressInterval = 100 * time.Millisecond
    37  
    38  type watchServer struct {
    39  	lg *zap.Logger
    40  
    41  	clusterID int64
    42  	memberID  int64
    43  
    44  	maxRequestBytes int
    45  
    46  	sg        etcdserver.RaftStatusGetter
    47  	watchable mvcc.WatchableKV
    48  	ag        AuthGetter
    49  }
    50  
    51  // NewWatchServer returns a new watch server.
    52  func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
    53  	srv := &watchServer{
    54  		lg: s.Cfg.Logger,
    55  
    56  		clusterID: int64(s.Cluster().ID()),
    57  		memberID:  int64(s.ID()),
    58  
    59  		maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes),
    60  
    61  		sg:        s,
    62  		watchable: s.Watchable(),
    63  		ag:        s,
    64  	}
    65  	if srv.lg == nil {
    66  		srv.lg = zap.NewNop()
    67  	}
    68  	if s.Cfg.WatchProgressNotifyInterval > 0 {
    69  		if s.Cfg.WatchProgressNotifyInterval < minWatchProgressInterval {
    70  			srv.lg.Warn(
    71  				"adjusting watch progress notify interval to minimum period",
    72  				zap.Duration("min-watch-progress-notify-interval", minWatchProgressInterval),
    73  			)
    74  			s.Cfg.WatchProgressNotifyInterval = minWatchProgressInterval
    75  		}
    76  		SetProgressReportInterval(s.Cfg.WatchProgressNotifyInterval)
    77  	}
    78  	return srv
    79  }
    80  
    81  var (
    82  	// External test can read this with GetProgressReportInterval()
    83  	// and change this to a small value to finish fast with
    84  	// SetProgressReportInterval().
    85  	progressReportInterval   = 10 * time.Minute
    86  	progressReportIntervalMu sync.RWMutex
    87  )
    88  
    89  // GetProgressReportInterval returns the current progress report interval (for testing).
    90  func GetProgressReportInterval() time.Duration {
    91  	progressReportIntervalMu.RLock()
    92  	interval := progressReportInterval
    93  	progressReportIntervalMu.RUnlock()
    94  
    95  	// add rand(1/10*progressReportInterval) as jitter so that etcdserver will not
    96  	// send progress notifications to watchers around the same time even when watchers
    97  	// are created around the same time (which is common when a client restarts itself).
    98  	jitter := time.Duration(rand.Int63n(int64(interval) / 10))
    99  
   100  	return interval + jitter
   101  }
   102  
   103  // SetProgressReportInterval updates the current progress report interval (for testing).
   104  func SetProgressReportInterval(newTimeout time.Duration) {
   105  	progressReportIntervalMu.Lock()
   106  	progressReportInterval = newTimeout
   107  	progressReportIntervalMu.Unlock()
   108  }
   109  
   110  // We send ctrl response inside the read loop. We do not want
   111  // send to block read, but we still want ctrl response we sent to
   112  // be serialized. Thus we use a buffered chan to solve the problem.
   113  // A small buffer should be OK for most cases, since we expect the
   114  // ctrl requests are infrequent.
   115  const ctrlStreamBufLen = 16
   116  
   117  // serverWatchStream is an etcd server side stream. It receives requests
   118  // from client side gRPC stream. It receives watch events from mvcc.WatchStream,
   119  // and creates responses that forwarded to gRPC stream.
   120  // It also forwards control message like watch created and canceled.
   121  type serverWatchStream struct {
   122  	lg *zap.Logger
   123  
   124  	clusterID int64
   125  	memberID  int64
   126  
   127  	maxRequestBytes int
   128  
   129  	sg        etcdserver.RaftStatusGetter
   130  	watchable mvcc.WatchableKV
   131  	ag        AuthGetter
   132  
   133  	gRPCStream  pb.Watch_WatchServer
   134  	watchStream mvcc.WatchStream
   135  	ctrlStream  chan *pb.WatchResponse
   136  
   137  	// mu protects progress, prevKV, fragment
   138  	mu sync.RWMutex
   139  	// tracks the watchID that stream might need to send progress to
   140  	// TODO: combine progress and prevKV into a single struct?
   141  	progress map[mvcc.WatchID]bool
   142  	// record watch IDs that need return previous key-value pair
   143  	prevKV map[mvcc.WatchID]bool
   144  	// records fragmented watch IDs
   145  	fragment map[mvcc.WatchID]bool
   146  
   147  	// closec indicates the stream is closed.
   148  	closec chan struct{}
   149  
   150  	// wg waits for the send loop to complete
   151  	wg sync.WaitGroup
   152  }
   153  
   154  func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
   155  	sws := serverWatchStream{
   156  		lg: ws.lg,
   157  
   158  		clusterID: ws.clusterID,
   159  		memberID:  ws.memberID,
   160  
   161  		maxRequestBytes: ws.maxRequestBytes,
   162  
   163  		sg:        ws.sg,
   164  		watchable: ws.watchable,
   165  		ag:        ws.ag,
   166  
   167  		gRPCStream:  stream,
   168  		watchStream: ws.watchable.NewWatchStream(),
   169  		// chan for sending control response like watcher created and canceled.
   170  		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),
   171  
   172  		progress: make(map[mvcc.WatchID]bool),
   173  		prevKV:   make(map[mvcc.WatchID]bool),
   174  		fragment: make(map[mvcc.WatchID]bool),
   175  
   176  		closec: make(chan struct{}),
   177  	}
   178  
   179  	sws.wg.Add(1)
   180  	go func() {
   181  		sws.sendLoop()
   182  		sws.wg.Done()
   183  	}()
   184  
   185  	errc := make(chan error, 1)
   186  	// Ideally recvLoop would also use sws.wg to signal its completion
   187  	// but when stream.Context().Done() is closed, the stream's recv
   188  	// may continue to block since it uses a different context, leading to
   189  	// deadlock when calling sws.close().
   190  	go func() {
   191  		if rerr := sws.recvLoop(); rerr != nil {
   192  			if isClientCtxErr(stream.Context().Err(), rerr) {
   193  				sws.lg.Debug("failed to receive watch request from gRPC stream", zap.Error(rerr))
   194  			} else {
   195  				sws.lg.Warn("failed to receive watch request from gRPC stream", zap.Error(rerr))
   196  				streamFailures.WithLabelValues("receive", "watch").Inc()
   197  			}
   198  			errc <- rerr
   199  		}
   200  	}()
   201  
   202  	// TODO: There's a race here. When a stream  is closed (e.g. due to a cancellation),
   203  	// the underlying error (e.g. a gRPC stream error) may be returned and handled
   204  	// through errc if the recv goroutine finishes before the send goroutine.
   205  	// When the recv goroutine wins, the stream error is retained. When recv loses
   206  	// the race, the underlying error is lost (unless the root error is propagated
   207  	// through Context.Err() which is not always the case (as callers have to decide
   208  	// to implement a custom context to do so). The stdlib context package builtins
   209  	// may be insufficient to carry semantically useful errors around and should be
   210  	// revisited.
   211  	select {
   212  	case err = <-errc:
   213  		if err == context.Canceled {
   214  			err = rpctypes.ErrGRPCWatchCanceled
   215  		}
   216  		close(sws.ctrlStream)
   217  	case <-stream.Context().Done():
   218  		err = stream.Context().Err()
   219  		if err == context.Canceled {
   220  			err = rpctypes.ErrGRPCWatchCanceled
   221  		}
   222  	}
   223  
   224  	sws.close()
   225  	return err
   226  }
   227  
   228  func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) error {
   229  	authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context())
   230  	if err != nil {
   231  		return err
   232  	}
   233  	if authInfo == nil {
   234  		// if auth is enabled, IsRangePermitted() can cause an error
   235  		authInfo = &auth.AuthInfo{}
   236  	}
   237  	return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd)
   238  }
   239  
   240  func (sws *serverWatchStream) recvLoop() error {
   241  	for {
   242  		req, err := sws.gRPCStream.Recv()
   243  		if err == io.EOF {
   244  			return nil
   245  		}
   246  		if err != nil {
   247  			return err
   248  		}
   249  
   250  		switch uv := req.RequestUnion.(type) {
   251  		case *pb.WatchRequest_CreateRequest:
   252  			if uv.CreateRequest == nil {
   253  				break
   254  			}
   255  
   256  			creq := uv.CreateRequest
   257  			if len(creq.Key) == 0 {
   258  				// \x00 is the smallest key
   259  				creq.Key = []byte{0}
   260  			}
   261  			if len(creq.RangeEnd) == 0 {
   262  				// force nil since watchstream.Watch distinguishes
   263  				// between nil and []byte{} for single key / >=
   264  				creq.RangeEnd = nil
   265  			}
   266  			if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
   267  				// support  >= key queries
   268  				creq.RangeEnd = []byte{}
   269  			}
   270  
   271  			err := sws.isWatchPermitted(creq)
   272  			if err != nil {
   273  				var cancelReason string
   274  				switch err {
   275  				case auth.ErrInvalidAuthToken:
   276  					cancelReason = rpctypes.ErrGRPCInvalidAuthToken.Error()
   277  				case auth.ErrAuthOldRevision:
   278  					cancelReason = rpctypes.ErrGRPCAuthOldRevision.Error()
   279  				case auth.ErrUserEmpty:
   280  					cancelReason = rpctypes.ErrGRPCUserEmpty.Error()
   281  				default:
   282  					if err != auth.ErrPermissionDenied {
   283  						sws.lg.Error("unexpected error code", zap.Error(err))
   284  					}
   285  					cancelReason = rpctypes.ErrGRPCPermissionDenied.Error()
   286  				}
   287  
   288  				wr := &pb.WatchResponse{
   289  					Header:       sws.newResponseHeader(sws.watchStream.Rev()),
   290  					WatchId:      clientv3.InvalidWatchID,
   291  					Canceled:     true,
   292  					Created:      true,
   293  					CancelReason: cancelReason,
   294  				}
   295  
   296  				select {
   297  				case sws.ctrlStream <- wr:
   298  					continue
   299  				case <-sws.closec:
   300  					return nil
   301  				}
   302  			}
   303  
   304  			filters := FiltersFromRequest(creq)
   305  
   306  			wsrev := sws.watchStream.Rev()
   307  			rev := creq.StartRevision
   308  			if rev == 0 {
   309  				rev = wsrev + 1
   310  			}
   311  			id, err := sws.watchStream.Watch(mvcc.WatchID(creq.WatchId), creq.Key, creq.RangeEnd, rev, filters...)
   312  			if err == nil {
   313  				sws.mu.Lock()
   314  				if creq.ProgressNotify {
   315  					sws.progress[id] = true
   316  				}
   317  				if creq.PrevKv {
   318  					sws.prevKV[id] = true
   319  				}
   320  				if creq.Fragment {
   321  					sws.fragment[id] = true
   322  				}
   323  				sws.mu.Unlock()
   324  			} else {
   325  				id = clientv3.InvalidWatchID
   326  			}
   327  
   328  			wr := &pb.WatchResponse{
   329  				Header:   sws.newResponseHeader(wsrev),
   330  				WatchId:  int64(id),
   331  				Created:  true,
   332  				Canceled: err != nil,
   333  			}
   334  			if err != nil {
   335  				wr.CancelReason = err.Error()
   336  			}
   337  			select {
   338  			case sws.ctrlStream <- wr:
   339  			case <-sws.closec:
   340  				return nil
   341  			}
   342  
   343  		case *pb.WatchRequest_CancelRequest:
   344  			if uv.CancelRequest != nil {
   345  				id := uv.CancelRequest.WatchId
   346  				err := sws.watchStream.Cancel(mvcc.WatchID(id))
   347  				if err == nil {
   348  					sws.ctrlStream <- &pb.WatchResponse{
   349  						Header:   sws.newResponseHeader(sws.watchStream.Rev()),
   350  						WatchId:  id,
   351  						Canceled: true,
   352  					}
   353  					sws.mu.Lock()
   354  					delete(sws.progress, mvcc.WatchID(id))
   355  					delete(sws.prevKV, mvcc.WatchID(id))
   356  					delete(sws.fragment, mvcc.WatchID(id))
   357  					sws.mu.Unlock()
   358  				}
   359  			}
   360  		case *pb.WatchRequest_ProgressRequest:
   361  			if uv.ProgressRequest != nil {
   362  				sws.mu.Lock()
   363  				sws.watchStream.RequestProgressAll()
   364  				sws.mu.Unlock()
   365  			}
   366  		default:
   367  			// we probably should not shutdown the entire stream when
   368  			// receive an valid command.
   369  			// so just do nothing instead.
   370  			continue
   371  		}
   372  	}
   373  }
   374  
   375  func (sws *serverWatchStream) sendLoop() {
   376  	// watch ids that are currently active
   377  	ids := make(map[mvcc.WatchID]struct{})
   378  	// watch responses pending on a watch id creation message
   379  	pending := make(map[mvcc.WatchID][]*pb.WatchResponse)
   380  
   381  	interval := GetProgressReportInterval()
   382  	progressTicker := time.NewTicker(interval)
   383  
   384  	defer func() {
   385  		progressTicker.Stop()
   386  		// drain the chan to clean up pending events
   387  		for ws := range sws.watchStream.Chan() {
   388  			mvcc.ReportEventReceived(len(ws.Events))
   389  		}
   390  		for _, wrs := range pending {
   391  			for _, ws := range wrs {
   392  				mvcc.ReportEventReceived(len(ws.Events))
   393  			}
   394  		}
   395  	}()
   396  
   397  	for {
   398  		select {
   399  		case wresp, ok := <-sws.watchStream.Chan():
   400  			if !ok {
   401  				return
   402  			}
   403  
   404  			// TODO: evs is []mvccpb.Event type
   405  			// either return []*mvccpb.Event from the mvcc package
   406  			// or define protocol buffer with []mvccpb.Event.
   407  			evs := wresp.Events
   408  			events := make([]*mvccpb.Event, len(evs))
   409  			sws.mu.RLock()
   410  			needPrevKV := sws.prevKV[wresp.WatchID]
   411  			sws.mu.RUnlock()
   412  			for i := range evs {
   413  				events[i] = &evs[i]
   414  				if needPrevKV && !IsCreateEvent(evs[i]) {
   415  					opt := mvcc.RangeOptions{Rev: evs[i].Kv.ModRevision - 1}
   416  					r, err := sws.watchable.Range(context.TODO(), evs[i].Kv.Key, nil, opt)
   417  					if err == nil && len(r.KVs) != 0 {
   418  						events[i].PrevKv = &(r.KVs[0])
   419  					}
   420  				}
   421  			}
   422  
   423  			canceled := wresp.CompactRevision != 0
   424  			wr := &pb.WatchResponse{
   425  				Header:          sws.newResponseHeader(wresp.Revision),
   426  				WatchId:         int64(wresp.WatchID),
   427  				Events:          events,
   428  				CompactRevision: wresp.CompactRevision,
   429  				Canceled:        canceled,
   430  			}
   431  
   432  			// Progress notifications can have WatchID -1
   433  			// if they announce on behalf of multiple watchers
   434  			if wresp.WatchID != clientv3.InvalidWatchID {
   435  				if _, okID := ids[wresp.WatchID]; !okID {
   436  					// buffer if id not yet announced
   437  					wrs := append(pending[wresp.WatchID], wr)
   438  					pending[wresp.WatchID] = wrs
   439  					continue
   440  				}
   441  			}
   442  
   443  			mvcc.ReportEventReceived(len(evs))
   444  
   445  			sws.mu.RLock()
   446  			fragmented, ok := sws.fragment[wresp.WatchID]
   447  			sws.mu.RUnlock()
   448  
   449  			var serr error
   450  			if !fragmented && !ok {
   451  				serr = sws.gRPCStream.Send(wr)
   452  			} else {
   453  				serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send)
   454  			}
   455  
   456  			if serr != nil {
   457  				if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) {
   458  					sws.lg.Debug("failed to send watch response to gRPC stream", zap.Error(serr))
   459  				} else {
   460  					sws.lg.Warn("failed to send watch response to gRPC stream", zap.Error(serr))
   461  					streamFailures.WithLabelValues("send", "watch").Inc()
   462  				}
   463  				return
   464  			}
   465  
   466  			sws.mu.Lock()
   467  			if len(evs) > 0 && sws.progress[wresp.WatchID] {
   468  				// elide next progress update if sent a key update
   469  				sws.progress[wresp.WatchID] = false
   470  			}
   471  			sws.mu.Unlock()
   472  
   473  		case c, ok := <-sws.ctrlStream:
   474  			if !ok {
   475  				return
   476  			}
   477  
   478  			if err := sws.gRPCStream.Send(c); err != nil {
   479  				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
   480  					sws.lg.Debug("failed to send watch control response to gRPC stream", zap.Error(err))
   481  				} else {
   482  					sws.lg.Warn("failed to send watch control response to gRPC stream", zap.Error(err))
   483  					streamFailures.WithLabelValues("send", "watch").Inc()
   484  				}
   485  				return
   486  			}
   487  
   488  			// track id creation
   489  			wid := mvcc.WatchID(c.WatchId)
   490  
   491  			if !(!(c.Canceled && c.Created) || wid == clientv3.InvalidWatchID) {
   492  				panic(fmt.Sprintf("unexpected watchId: %d, wanted: %d, since both 'Canceled' and 'Created' are true", wid, clientv3.InvalidWatchID))
   493  			}
   494  
   495  			if c.Canceled && wid != clientv3.InvalidWatchID {
   496  				delete(ids, wid)
   497  				continue
   498  			}
   499  			if c.Created {
   500  				// flush buffered events
   501  				ids[wid] = struct{}{}
   502  				for _, v := range pending[wid] {
   503  					mvcc.ReportEventReceived(len(v.Events))
   504  					if err := sws.gRPCStream.Send(v); err != nil {
   505  						if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
   506  							sws.lg.Debug("failed to send pending watch response to gRPC stream", zap.Error(err))
   507  						} else {
   508  							sws.lg.Warn("failed to send pending watch response to gRPC stream", zap.Error(err))
   509  							streamFailures.WithLabelValues("send", "watch").Inc()
   510  						}
   511  						return
   512  					}
   513  				}
   514  				delete(pending, wid)
   515  			}
   516  
   517  		case <-progressTicker.C:
   518  			sws.mu.Lock()
   519  			for id, ok := range sws.progress {
   520  				if ok {
   521  					sws.watchStream.RequestProgress(id)
   522  				}
   523  				sws.progress[id] = true
   524  			}
   525  			sws.mu.Unlock()
   526  
   527  		case <-sws.closec:
   528  			return
   529  		}
   530  	}
   531  }
   532  
   533  func IsCreateEvent(e mvccpb.Event) bool {
   534  	return e.Type == mvccpb.PUT && e.Kv.CreateRevision == e.Kv.ModRevision
   535  }
   536  
   537  func sendFragments(
   538  	wr *pb.WatchResponse,
   539  	maxRequestBytes int,
   540  	sendFunc func(*pb.WatchResponse) error) error {
   541  	// no need to fragment if total request size is smaller
   542  	// than max request limit or response contains only one event
   543  	if wr.Size() < maxRequestBytes || len(wr.Events) < 2 {
   544  		return sendFunc(wr)
   545  	}
   546  
   547  	ow := *wr
   548  	ow.Events = make([]*mvccpb.Event, 0)
   549  	ow.Fragment = true
   550  
   551  	var idx int
   552  	for {
   553  		cur := ow
   554  		for _, ev := range wr.Events[idx:] {
   555  			cur.Events = append(cur.Events, ev)
   556  			if len(cur.Events) > 1 && cur.Size() >= maxRequestBytes {
   557  				cur.Events = cur.Events[:len(cur.Events)-1]
   558  				break
   559  			}
   560  			idx++
   561  		}
   562  		if idx == len(wr.Events) {
   563  			// last response has no more fragment
   564  			cur.Fragment = false
   565  		}
   566  		if err := sendFunc(&cur); err != nil {
   567  			return err
   568  		}
   569  		if !cur.Fragment {
   570  			break
   571  		}
   572  	}
   573  	return nil
   574  }
   575  
// close shuts the stream down. It first closes the mvcc watch stream, then
// closes closec so that sendLoop and any control-response send in recvLoop
// stop waiting, and finally blocks until sendLoop has returned. The order of
// these calls is intentional; do not reorder.
func (sws *serverWatchStream) close() {
	sws.watchStream.Close()
	close(sws.closec)
	sws.wg.Wait()
}
   581  
   582  func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
   583  	return &pb.ResponseHeader{
   584  		ClusterId: uint64(sws.clusterID),
   585  		MemberId:  uint64(sws.memberID),
   586  		Revision:  rev,
   587  		RaftTerm:  sws.sg.Term(),
   588  	}
   589  }
   590  
   591  func filterNoDelete(e mvccpb.Event) bool {
   592  	return e.Type == mvccpb.DELETE
   593  }
   594  
   595  func filterNoPut(e mvccpb.Event) bool {
   596  	return e.Type == mvccpb.PUT
   597  }
   598  
   599  // FiltersFromRequest returns "mvcc.FilterFunc" from a given watch create request.
   600  func FiltersFromRequest(creq *pb.WatchCreateRequest) []mvcc.FilterFunc {
   601  	filters := make([]mvcc.FilterFunc, 0, len(creq.Filters))
   602  	for _, ft := range creq.Filters {
   603  		switch ft {
   604  		case pb.WatchCreateRequest_NOPUT:
   605  			filters = append(filters, filterNoPut)
   606  		case pb.WatchCreateRequest_NODELETE:
   607  			filters = append(filters, filterNoDelete)
   608  		default:
   609  		}
   610  	}
   611  	return filters
   612  }
   613  

View as plain text