1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package v3rpc
16
17 import (
18 "context"
19 "fmt"
20 "io"
21 "math/rand"
22 "sync"
23 "time"
24
25 pb "go.etcd.io/etcd/api/v3/etcdserverpb"
26 "go.etcd.io/etcd/api/v3/mvccpb"
27 "go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
28 clientv3 "go.etcd.io/etcd/client/v3"
29 "go.etcd.io/etcd/server/v3/auth"
30 "go.etcd.io/etcd/server/v3/etcdserver"
31 "go.etcd.io/etcd/server/v3/mvcc"
32
33 "go.uber.org/zap"
34 )
35
// minWatchProgressInterval is the smallest progress-notify interval that
// NewWatchServer will install; positive configured values below it are raised
// to this floor.
const minWatchProgressInterval = time.Millisecond * 100
37
38 type watchServer struct {
39 lg *zap.Logger
40
41 clusterID int64
42 memberID int64
43
44 maxRequestBytes int
45
46 sg etcdserver.RaftStatusGetter
47 watchable mvcc.WatchableKV
48 ag AuthGetter
49 }
50
51
52 func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
53 srv := &watchServer{
54 lg: s.Cfg.Logger,
55
56 clusterID: int64(s.Cluster().ID()),
57 memberID: int64(s.ID()),
58
59 maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes),
60
61 sg: s,
62 watchable: s.Watchable(),
63 ag: s,
64 }
65 if srv.lg == nil {
66 srv.lg = zap.NewNop()
67 }
68 if s.Cfg.WatchProgressNotifyInterval > 0 {
69 if s.Cfg.WatchProgressNotifyInterval < minWatchProgressInterval {
70 srv.lg.Warn(
71 "adjusting watch progress notify interval to minimum period",
72 zap.Duration("min-watch-progress-notify-interval", minWatchProgressInterval),
73 )
74 s.Cfg.WatchProgressNotifyInterval = minWatchProgressInterval
75 }
76 SetProgressReportInterval(s.Cfg.WatchProgressNotifyInterval)
77 }
78 return srv
79 }
80
var (
	// progressReportInterval is the base period between progress
	// notifications for watchers that asked for them. It is read by
	// GetProgressReportInterval and replaced by SetProgressReportInterval.
	progressReportInterval = 10 * time.Minute
	// progressReportIntervalMu guards progressReportInterval.
	progressReportIntervalMu sync.RWMutex
)

// GetProgressReportInterval returns the current progress report interval plus
// a random jitter in [0, interval/10). Each caller therefore gets a slightly
// different period, so streams started together do not tick in lockstep.
//
// When interval/10 is not positive (an interval under 10ns, or a zero or
// negative value installed with SetProgressReportInterval), no jitter is added,
// because rand.Int63n panics on a non-positive bound.
func GetProgressReportInterval() time.Duration {
	progressReportIntervalMu.RLock()
	interval := progressReportInterval
	progressReportIntervalMu.RUnlock()

	if bound := int64(interval) / 10; bound > 0 {
		interval += time.Duration(rand.Int63n(bound))
	}
	return interval
}

// SetProgressReportInterval replaces the base progress report interval. It
// affects later calls to GetProgressReportInterval, i.e. watch streams whose
// send loop starts after the call.
func SetProgressReportInterval(newTimeout time.Duration) {
	progressReportIntervalMu.Lock()
	progressReportInterval = newTimeout
	progressReportIntervalMu.Unlock()
}
109
110
111
112
113
114
// ctrlStreamBufLen is the capacity of each serverWatchStream's ctrlStream. It
// is how many control responses recvLoop can queue before sendLoop reads them.
const ctrlStreamBufLen = 16
116
117
118
119
120
121 type serverWatchStream struct {
122 lg *zap.Logger
123
124 clusterID int64
125 memberID int64
126
127 maxRequestBytes int
128
129 sg etcdserver.RaftStatusGetter
130 watchable mvcc.WatchableKV
131 ag AuthGetter
132
133 gRPCStream pb.Watch_WatchServer
134 watchStream mvcc.WatchStream
135 ctrlStream chan *pb.WatchResponse
136
137
138 mu sync.RWMutex
139
140
141 progress map[mvcc.WatchID]bool
142
143 prevKV map[mvcc.WatchID]bool
144
145 fragment map[mvcc.WatchID]bool
146
147
148 closec chan struct{}
149
150
151 wg sync.WaitGroup
152 }
153
// Watch serves one Watch RPC.
//
// It creates a serverWatchStream over a fresh mvcc watch stream and starts
// sendLoop and recvLoop in their own goroutines. It then blocks until either
// recvLoop reports an error or the stream's context is done. Either way, it
// tears the stream down with close(). A context.Canceled result is reported as
// rpctypes.ErrGRPCWatchCanceled.
func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
	sws := serverWatchStream{
		lg: ws.lg,

		clusterID: ws.clusterID,
		memberID:  ws.memberID,

		maxRequestBytes: ws.maxRequestBytes,

		sg:        ws.sg,
		watchable: ws.watchable,
		ag:        ws.ag,

		gRPCStream:  stream,
		watchStream: ws.watchable.NewWatchStream(),

		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),

		progress: make(map[mvcc.WatchID]bool),
		prevKV:   make(map[mvcc.WatchID]bool),
		fragment: make(map[mvcc.WatchID]bool),

		closec: make(chan struct{}),
	}

	// sendLoop is tracked by wg so close() can wait for it to finish.
	sws.wg.Add(1)
	go func() {
		sws.sendLoop()
		sws.wg.Done()
	}()

	// Buffered so the recv goroutine's single send never blocks, even after
	// Watch has already left through the context branch below.
	errc := make(chan error, 1)

	// recvLoop is not tracked by wg. It only posts to errc on error; a nil
	// return (client EOF) leaves Watch waiting on the context. NOTE(review):
	// presumably Recv only unblocks once the handler returns, which is why
	// close() cannot wait for this goroutine — confirm against gRPC semantics.
	go func() {
		if rerr := sws.recvLoop(); rerr != nil {
			if isClientCtxErr(stream.Context().Err(), rerr) {
				sws.lg.Debug("failed to receive watch request from gRPC stream", zap.Error(rerr))
			} else {
				sws.lg.Warn("failed to receive watch request from gRPC stream", zap.Error(rerr))
				streamFailures.WithLabelValues("receive", "watch").Inc()
			}
			errc <- rerr
		}
	}()

	// Whichever terminates first decides the returned error. If both happen,
	// the one not selected is dropped.
	select {
	case err = <-errc:
		if err == context.Canceled {
			err = rpctypes.ErrGRPCWatchCanceled
		}
		// recvLoop, the only sender on ctrlStream in this file, has returned,
		// so closing it is safe. sendLoop returns when it sees it closed.
		close(sws.ctrlStream)
	case <-stream.Context().Done():
		// ctrlStream stays open: recvLoop may still be running and sending.
		err = stream.Context().Err()
		if err == context.Canceled {
			err = rpctypes.ErrGRPCWatchCanceled
		}
	}

	sws.close()
	return err
}
227
228 func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) error {
229 authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context())
230 if err != nil {
231 return err
232 }
233 if authInfo == nil {
234
235 authInfo = &auth.AuthInfo{}
236 }
237 return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd)
238 }
239
240 func (sws *serverWatchStream) recvLoop() error {
241 for {
242 req, err := sws.gRPCStream.Recv()
243 if err == io.EOF {
244 return nil
245 }
246 if err != nil {
247 return err
248 }
249
250 switch uv := req.RequestUnion.(type) {
251 case *pb.WatchRequest_CreateRequest:
252 if uv.CreateRequest == nil {
253 break
254 }
255
256 creq := uv.CreateRequest
257 if len(creq.Key) == 0 {
258
259 creq.Key = []byte{0}
260 }
261 if len(creq.RangeEnd) == 0 {
262
263
264 creq.RangeEnd = nil
265 }
266 if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
267
268 creq.RangeEnd = []byte{}
269 }
270
271 err := sws.isWatchPermitted(creq)
272 if err != nil {
273 var cancelReason string
274 switch err {
275 case auth.ErrInvalidAuthToken:
276 cancelReason = rpctypes.ErrGRPCInvalidAuthToken.Error()
277 case auth.ErrAuthOldRevision:
278 cancelReason = rpctypes.ErrGRPCAuthOldRevision.Error()
279 case auth.ErrUserEmpty:
280 cancelReason = rpctypes.ErrGRPCUserEmpty.Error()
281 default:
282 if err != auth.ErrPermissionDenied {
283 sws.lg.Error("unexpected error code", zap.Error(err))
284 }
285 cancelReason = rpctypes.ErrGRPCPermissionDenied.Error()
286 }
287
288 wr := &pb.WatchResponse{
289 Header: sws.newResponseHeader(sws.watchStream.Rev()),
290 WatchId: clientv3.InvalidWatchID,
291 Canceled: true,
292 Created: true,
293 CancelReason: cancelReason,
294 }
295
296 select {
297 case sws.ctrlStream <- wr:
298 continue
299 case <-sws.closec:
300 return nil
301 }
302 }
303
304 filters := FiltersFromRequest(creq)
305
306 wsrev := sws.watchStream.Rev()
307 rev := creq.StartRevision
308 if rev == 0 {
309 rev = wsrev + 1
310 }
311 id, err := sws.watchStream.Watch(mvcc.WatchID(creq.WatchId), creq.Key, creq.RangeEnd, rev, filters...)
312 if err == nil {
313 sws.mu.Lock()
314 if creq.ProgressNotify {
315 sws.progress[id] = true
316 }
317 if creq.PrevKv {
318 sws.prevKV[id] = true
319 }
320 if creq.Fragment {
321 sws.fragment[id] = true
322 }
323 sws.mu.Unlock()
324 } else {
325 id = clientv3.InvalidWatchID
326 }
327
328 wr := &pb.WatchResponse{
329 Header: sws.newResponseHeader(wsrev),
330 WatchId: int64(id),
331 Created: true,
332 Canceled: err != nil,
333 }
334 if err != nil {
335 wr.CancelReason = err.Error()
336 }
337 select {
338 case sws.ctrlStream <- wr:
339 case <-sws.closec:
340 return nil
341 }
342
343 case *pb.WatchRequest_CancelRequest:
344 if uv.CancelRequest != nil {
345 id := uv.CancelRequest.WatchId
346 err := sws.watchStream.Cancel(mvcc.WatchID(id))
347 if err == nil {
348 sws.ctrlStream <- &pb.WatchResponse{
349 Header: sws.newResponseHeader(sws.watchStream.Rev()),
350 WatchId: id,
351 Canceled: true,
352 }
353 sws.mu.Lock()
354 delete(sws.progress, mvcc.WatchID(id))
355 delete(sws.prevKV, mvcc.WatchID(id))
356 delete(sws.fragment, mvcc.WatchID(id))
357 sws.mu.Unlock()
358 }
359 }
360 case *pb.WatchRequest_ProgressRequest:
361 if uv.ProgressRequest != nil {
362 sws.mu.Lock()
363 sws.watchStream.RequestProgressAll()
364 sws.mu.Unlock()
365 }
366 default:
367
368
369
370 continue
371 }
372 }
373 }
374
375 func (sws *serverWatchStream) sendLoop() {
376
377 ids := make(map[mvcc.WatchID]struct{})
378
379 pending := make(map[mvcc.WatchID][]*pb.WatchResponse)
380
381 interval := GetProgressReportInterval()
382 progressTicker := time.NewTicker(interval)
383
384 defer func() {
385 progressTicker.Stop()
386
387 for ws := range sws.watchStream.Chan() {
388 mvcc.ReportEventReceived(len(ws.Events))
389 }
390 for _, wrs := range pending {
391 for _, ws := range wrs {
392 mvcc.ReportEventReceived(len(ws.Events))
393 }
394 }
395 }()
396
397 for {
398 select {
399 case wresp, ok := <-sws.watchStream.Chan():
400 if !ok {
401 return
402 }
403
404
405
406
407 evs := wresp.Events
408 events := make([]*mvccpb.Event, len(evs))
409 sws.mu.RLock()
410 needPrevKV := sws.prevKV[wresp.WatchID]
411 sws.mu.RUnlock()
412 for i := range evs {
413 events[i] = &evs[i]
414 if needPrevKV && !IsCreateEvent(evs[i]) {
415 opt := mvcc.RangeOptions{Rev: evs[i].Kv.ModRevision - 1}
416 r, err := sws.watchable.Range(context.TODO(), evs[i].Kv.Key, nil, opt)
417 if err == nil && len(r.KVs) != 0 {
418 events[i].PrevKv = &(r.KVs[0])
419 }
420 }
421 }
422
423 canceled := wresp.CompactRevision != 0
424 wr := &pb.WatchResponse{
425 Header: sws.newResponseHeader(wresp.Revision),
426 WatchId: int64(wresp.WatchID),
427 Events: events,
428 CompactRevision: wresp.CompactRevision,
429 Canceled: canceled,
430 }
431
432
433
434 if wresp.WatchID != clientv3.InvalidWatchID {
435 if _, okID := ids[wresp.WatchID]; !okID {
436
437 wrs := append(pending[wresp.WatchID], wr)
438 pending[wresp.WatchID] = wrs
439 continue
440 }
441 }
442
443 mvcc.ReportEventReceived(len(evs))
444
445 sws.mu.RLock()
446 fragmented, ok := sws.fragment[wresp.WatchID]
447 sws.mu.RUnlock()
448
449 var serr error
450 if !fragmented && !ok {
451 serr = sws.gRPCStream.Send(wr)
452 } else {
453 serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send)
454 }
455
456 if serr != nil {
457 if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) {
458 sws.lg.Debug("failed to send watch response to gRPC stream", zap.Error(serr))
459 } else {
460 sws.lg.Warn("failed to send watch response to gRPC stream", zap.Error(serr))
461 streamFailures.WithLabelValues("send", "watch").Inc()
462 }
463 return
464 }
465
466 sws.mu.Lock()
467 if len(evs) > 0 && sws.progress[wresp.WatchID] {
468
469 sws.progress[wresp.WatchID] = false
470 }
471 sws.mu.Unlock()
472
473 case c, ok := <-sws.ctrlStream:
474 if !ok {
475 return
476 }
477
478 if err := sws.gRPCStream.Send(c); err != nil {
479 if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
480 sws.lg.Debug("failed to send watch control response to gRPC stream", zap.Error(err))
481 } else {
482 sws.lg.Warn("failed to send watch control response to gRPC stream", zap.Error(err))
483 streamFailures.WithLabelValues("send", "watch").Inc()
484 }
485 return
486 }
487
488
489 wid := mvcc.WatchID(c.WatchId)
490
491 if !(!(c.Canceled && c.Created) || wid == clientv3.InvalidWatchID) {
492 panic(fmt.Sprintf("unexpected watchId: %d, wanted: %d, since both 'Canceled' and 'Created' are true", wid, clientv3.InvalidWatchID))
493 }
494
495 if c.Canceled && wid != clientv3.InvalidWatchID {
496 delete(ids, wid)
497 continue
498 }
499 if c.Created {
500
501 ids[wid] = struct{}{}
502 for _, v := range pending[wid] {
503 mvcc.ReportEventReceived(len(v.Events))
504 if err := sws.gRPCStream.Send(v); err != nil {
505 if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
506 sws.lg.Debug("failed to send pending watch response to gRPC stream", zap.Error(err))
507 } else {
508 sws.lg.Warn("failed to send pending watch response to gRPC stream", zap.Error(err))
509 streamFailures.WithLabelValues("send", "watch").Inc()
510 }
511 return
512 }
513 }
514 delete(pending, wid)
515 }
516
517 case <-progressTicker.C:
518 sws.mu.Lock()
519 for id, ok := range sws.progress {
520 if ok {
521 sws.watchStream.RequestProgress(id)
522 }
523 sws.progress[id] = true
524 }
525 sws.mu.Unlock()
526
527 case <-sws.closec:
528 return
529 }
530 }
531 }
532
533 func IsCreateEvent(e mvccpb.Event) bool {
534 return e.Type == mvccpb.PUT && e.Kv.CreateRevision == e.Kv.ModRevision
535 }
536
537 func sendFragments(
538 wr *pb.WatchResponse,
539 maxRequestBytes int,
540 sendFunc func(*pb.WatchResponse) error) error {
541
542
543 if wr.Size() < maxRequestBytes || len(wr.Events) < 2 {
544 return sendFunc(wr)
545 }
546
547 ow := *wr
548 ow.Events = make([]*mvccpb.Event, 0)
549 ow.Fragment = true
550
551 var idx int
552 for {
553 cur := ow
554 for _, ev := range wr.Events[idx:] {
555 cur.Events = append(cur.Events, ev)
556 if len(cur.Events) > 1 && cur.Size() >= maxRequestBytes {
557 cur.Events = cur.Events[:len(cur.Events)-1]
558 break
559 }
560 idx++
561 }
562 if idx == len(wr.Events) {
563
564 cur.Fragment = false
565 }
566 if err := sendFunc(&cur); err != nil {
567 return err
568 }
569 if !cur.Fragment {
570 break
571 }
572 }
573 return nil
574 }
575
// close tears the stream down in three steps:
//   - close the mvcc watch stream;
//   - close closec, which stops sendLoop and abandons recvLoop sends that
//     select on it;
//   - wait for sendLoop to return.
//
// It must be called at most once, because closing closec twice panics. Watch
// calls it exactly once. NOTE(review): sendLoop's deferred drain ranges over
// watchStream.Chan() until it is closed. Presumably watchStream.Close() closes
// it, which is why it runs before wg.Wait() — confirm in mvcc.
func (sws *serverWatchStream) close() {
	sws.watchStream.Close()
	close(sws.closec)
	sws.wg.Wait()
}
581
582 func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
583 return &pb.ResponseHeader{
584 ClusterId: uint64(sws.clusterID),
585 MemberId: uint64(sws.memberID),
586 Revision: rev,
587 RaftTerm: sws.sg.Term(),
588 }
589 }
590
591 func filterNoDelete(e mvccpb.Event) bool {
592 return e.Type == mvccpb.DELETE
593 }
594
595 func filterNoPut(e mvccpb.Event) bool {
596 return e.Type == mvccpb.PUT
597 }
598
599
600 func FiltersFromRequest(creq *pb.WatchCreateRequest) []mvcc.FilterFunc {
601 filters := make([]mvcc.FilterFunc, 0, len(creq.Filters))
602 for _, ft := range creq.Filters {
603 switch ft {
604 case pb.WatchCreateRequest_NOPUT:
605 filters = append(filters, filterNoPut)
606 case pb.WatchCreateRequest_NODELETE:
607 filters = append(filters, filterNoDelete)
608 default:
609 }
610 }
611 return filters
612 }
613
View as plain text