1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package rpcreplay
16
17 import (
18 "bufio"
19 "bytes"
20 "context"
21 "encoding/binary"
22 "errors"
23 "fmt"
24 "io"
25 "log"
26 "net"
27 "os"
28 "sync"
29
30 pb "cloud.google.com/go/rpcreplay/proto/rpcreplay"
31 spb "google.golang.org/genproto/googleapis/rpc/status"
32 "google.golang.org/grpc"
33 "google.golang.org/grpc/metadata"
34 "google.golang.org/grpc/status"
35 "google.golang.org/protobuf/encoding/prototext"
36 "google.golang.org/protobuf/proto"
37 "google.golang.org/protobuf/types/known/anypb"
38 )
39
40
// A Recorder records RPCs for later playback.
type Recorder struct {
	// mu guards w, next and err.
	mu sync.Mutex

	// w is the buffered destination for the header and all entries.
	w *bufio.Writer

	// f is the file created by NewRecorder and closed by Close; it is nil
	// when the Recorder was created with NewRecorderWriter.
	f *os.File

	// next is the index that will be assigned to the next entry written.
	// Indexes start at 1.
	next int

	// err is the first error that made the recording unusable; once set,
	// writeEntry refuses to write and Close returns it.
	err error

	// BeforeFunc, if non-nil, is called for unary RPCs with the method name
	// and the recorded copy of the request (before it is written) or of the
	// response (after the call returns). Changes it makes to the message
	// affect only what is recorded, not what is sent or returned to the
	// caller. If it returns an error, the RPC returns that error.
	BeforeFunc func(string, proto.Message) error
}
56
57
58
59
60
61 func NewRecorder(filename string, initial []byte) (*Recorder, error) {
62 f, err := os.Create(filename)
63 if err != nil {
64 return nil, err
65 }
66 rec, err := NewRecorderWriter(f, initial)
67 if err != nil {
68 _ = f.Close()
69 return nil, err
70 }
71 rec.f = f
72 return rec, nil
73 }
74
75
76
77
78
79 func NewRecorderWriter(w io.Writer, initial []byte) (*Recorder, error) {
80 bw := bufio.NewWriter(w)
81 if err := writeHeader(bw, initial); err != nil {
82 return nil, err
83 }
84 return &Recorder{w: bw, next: 1}, nil
85 }
86
87
88
89 func (r *Recorder) DialOptions() []grpc.DialOption {
90 return []grpc.DialOption{
91 grpc.WithUnaryInterceptor(r.interceptUnary),
92 grpc.WithStreamInterceptor(r.interceptStream),
93 }
94 }
95
96
97 func (r *Recorder) Close() error {
98 r.mu.Lock()
99 defer r.mu.Unlock()
100 if r.err != nil {
101 return r.err
102 }
103 err := r.w.Flush()
104 if r.f != nil {
105 if err2 := r.f.Close(); err == nil {
106 err = err2
107 }
108 }
109 return err
110 }
111
112
113 func (r *Recorder) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
114 ereq := &entry{
115 kind: pb.Entry_REQUEST,
116 method: method,
117 msg: message{msg: proto.Clone(req.(proto.Message))},
118 }
119
120 if r.BeforeFunc != nil {
121 if err := r.BeforeFunc(method, ereq.msg.msg); err != nil {
122 return err
123 }
124 }
125 refIndex, err := r.writeEntry(ereq)
126 if err != nil {
127 return err
128 }
129 ierr := invoker(ctx, method, req, res, cc, opts...)
130 eres := &entry{
131 kind: pb.Entry_RESPONSE,
132 refIndex: refIndex,
133 }
134
135
136
137
138 if _, ok := status.FromError(ierr); !ok {
139 r.mu.Lock()
140 r.err = fmt.Errorf("saw non-status error in %s response: %w (%T)", method, ierr, ierr)
141 r.mu.Unlock()
142 return ierr
143 }
144 eres.msg.set(proto.Clone(res.(proto.Message)), ierr)
145 if r.BeforeFunc != nil {
146 if err := r.BeforeFunc(method, eres.msg.msg); err != nil {
147 return err
148 }
149 }
150 if _, err := r.writeEntry(eres); err != nil {
151 return err
152 }
153 return ierr
154 }
155
156 func (r *Recorder) writeEntry(e *entry) (int, error) {
157 r.mu.Lock()
158 defer r.mu.Unlock()
159 if r.err != nil {
160 return 0, r.err
161 }
162 err := writeEntry(r.w, e)
163 if err != nil {
164 r.err = err
165 return 0, err
166 }
167 n := r.next
168 r.next++
169 return n, nil
170 }
171
172 func (r *Recorder) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
173 cstream, serr := streamer(ctx, desc, cc, method, opts...)
174 e := &entry{
175 kind: pb.Entry_CREATE_STREAM,
176 method: method,
177 }
178 e.msg.set(nil, serr)
179 refIndex, err := r.writeEntry(e)
180 if err != nil {
181 return nil, err
182 }
183 return &recClientStream{
184 ctx: ctx,
185 rec: r,
186 cstream: cstream,
187 refIndex: refIndex,
188 }, serr
189 }
190
191
192
193
194 type recClientStream struct {
195 ctx context.Context
196 rec *Recorder
197 cstream grpc.ClientStream
198 refIndex int
199 }
200
201 func (rcs *recClientStream) Context() context.Context { return rcs.ctx }
202
203 func (rcs *recClientStream) SendMsg(m interface{}) error {
204 serr := rcs.cstream.SendMsg(m)
205 e := &entry{
206 kind: pb.Entry_SEND,
207 refIndex: rcs.refIndex,
208 }
209 e.msg.set(m, serr)
210 if _, err := rcs.rec.writeEntry(e); err != nil {
211 return err
212 }
213 return serr
214 }
215
216 func (rcs *recClientStream) RecvMsg(m interface{}) error {
217 serr := rcs.cstream.RecvMsg(m)
218 e := &entry{
219 kind: pb.Entry_RECV,
220 refIndex: rcs.refIndex,
221 }
222 e.msg.set(m, serr)
223 if _, err := rcs.rec.writeEntry(e); err != nil {
224 return err
225 }
226 return serr
227 }
228
229 func (rcs *recClientStream) Header() (metadata.MD, error) {
230
231 return rcs.cstream.Header()
232 }
233
234 func (rcs *recClientStream) Trailer() metadata.MD {
235
236 return rcs.cstream.Trailer()
237 }
238
239 func (rcs *recClientStream) CloseSend() error {
240
241 return rcs.cstream.CloseSend()
242 }
243
244
// A Replayer replays a set of RPCs saved by a Recorder.
type Replayer struct {
	// initial is the initial state read from the file header.
	initial []byte

	// log is used for debug logging. NewReplayerReader sets it to a no-op
	// function; SetLogFunc replaces it.
	log func(format string, v ...interface{})

	// mu guards calls and streams while they are consumed during replay.
	mu sync.Mutex

	// calls are the recorded unary calls; a call is set to nil once it has
	// been replayed.
	calls []*call

	// streams are the recorded streams; a stream is set to nil once it has
	// been claimed for replay.
	streams []*stream

	// BeforeFunc, if non-nil, is called on each unary request with the
	// method name before the request is looked up in the recording. It
	// receives the caller's request message itself, so changes it makes
	// affect the lookup. If it returns an error, the RPC returns that error.
	BeforeFunc func(string, proto.Message) error
}
259
260
// A call is a recorded unary RPC: its method, its request and the recorded
// response (a message or an error).
type call struct {
	method   string
	request  proto.Message
	response message
}

// A stream is a recorded gRPC stream. During replay, sends and recvs are
// consumed from the front as the client sends and receives.
type stream struct {
	method string

	// createIndex is the index of the CREATE_STREAM entry that opened the
	// stream; it appears in replay error messages.
	createIndex int

	// createErr is the error recorded when the stream was created, if any.
	createErr error

	// sends and recvs are the recorded messages, in order.
	sends []message
	recvs []message
}
276
277
278 func NewReplayer(filename string) (*Replayer, error) {
279 f, err := os.Open(filename)
280 if err != nil {
281 return nil, err
282 }
283 defer f.Close()
284 return NewReplayerReader(f)
285 }
286
287
288 func NewReplayerReader(r io.Reader) (*Replayer, error) {
289 rep := &Replayer{
290 log: func(string, ...interface{}) {},
291 }
292 if err := rep.read(r); err != nil {
293 return nil, err
294 }
295 return rep, nil
296 }
297
298
299
300
301 func (rep *Replayer) read(r io.Reader) error {
302 r = bufio.NewReader(r)
303 bytes, err := readHeader(r)
304 if err != nil {
305 return err
306 }
307 rep.initial = bytes
308
309 callsByIndex := map[int]*call{}
310 streamsByIndex := map[int]*stream{}
311 for i := 1; ; i++ {
312 e, err := readEntry(r)
313 if err != nil {
314 return err
315 }
316 if e == nil {
317 break
318 }
319 switch e.kind {
320 case pb.Entry_REQUEST:
321 callsByIndex[i] = &call{
322 method: e.method,
323 request: e.msg.msg,
324 }
325
326 case pb.Entry_RESPONSE:
327 call := callsByIndex[e.refIndex]
328 if call == nil {
329 return fmt.Errorf("replayer: no request for response #%d", i)
330 }
331 delete(callsByIndex, e.refIndex)
332 call.response = e.msg
333 rep.calls = append(rep.calls, call)
334
335 case pb.Entry_CREATE_STREAM:
336 s := &stream{method: e.method, createIndex: i}
337 s.createErr = e.msg.err
338 streamsByIndex[i] = s
339 rep.streams = append(rep.streams, s)
340
341 case pb.Entry_SEND:
342 s := streamsByIndex[e.refIndex]
343 if s == nil {
344 return fmt.Errorf("replayer: no stream for send #%d", i)
345 }
346 s.sends = append(s.sends, e.msg)
347
348 case pb.Entry_RECV:
349 s := streamsByIndex[e.refIndex]
350 if s == nil {
351 return fmt.Errorf("replayer: no stream for recv #%d", i)
352 }
353 s.recvs = append(s.recvs, e.msg)
354
355 default:
356 return fmt.Errorf("replayer: unknown kind %s", e.kind)
357 }
358 }
359 if len(callsByIndex) > 0 {
360 return fmt.Errorf("replayer: %d unmatched requests", len(callsByIndex))
361 }
362 return nil
363 }
364
365
366
367 func (rep *Replayer) DialOptions() []grpc.DialOption {
368 return []grpc.DialOption{
369
370
371
372 grpc.WithBlock(),
373 grpc.WithUnaryInterceptor(rep.interceptUnary),
374 grpc.WithStreamInterceptor(rep.interceptStream),
375 }
376 }
377
378
379 func (rep *Replayer) Connection() (*grpc.ClientConn, error) {
380
381
382
383 srv := grpc.NewServer()
384 l, err := net.Listen("tcp", "localhost:0")
385 if err != nil {
386 return nil, err
387 }
388 go func() {
389 if err := srv.Serve(l); err != nil {
390 panic(err)
391 }
392 }()
393 conn, err := grpc.Dial(l.Addr().String(),
394 append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...)
395 if err != nil {
396 return nil, err
397 }
398 conn.Close()
399 srv.Stop()
400 return conn, nil
401 }
402
403
404 func (rep *Replayer) Initial() []byte { return rep.initial }
405
406
407
408 func (rep *Replayer) SetLogFunc(f func(format string, v ...interface{})) {
409 rep.log = f
410 }
411
412
413 func (rep *Replayer) Close() error {
414 return nil
415 }
416
417 func (rep *Replayer) interceptUnary(_ context.Context, method string, req, res interface{}, _ *grpc.ClientConn, _ grpc.UnaryInvoker, _ ...grpc.CallOption) error {
418 mreq := req.(proto.Message)
419 if rep.BeforeFunc != nil {
420 if err := rep.BeforeFunc(method, mreq); err != nil {
421 return err
422 }
423 }
424 rep.log("request %s (%s)", method, req)
425 call := rep.extractCall(method, mreq)
426 if call == nil {
427 return fmt.Errorf("replayer: request not found: %s", mreq)
428 }
429 rep.log("returning %v", call.response)
430 if call.response.err != nil {
431 return call.response.err
432 }
433 proto.Merge(res.(proto.Message), call.response.msg)
434 return nil
435 }
436
437 func (rep *Replayer) interceptStream(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, method string, _ grpc.Streamer, _ ...grpc.CallOption) (grpc.ClientStream, error) {
438 rep.log("create-stream %s", method)
439 return &repClientStream{ctx: ctx, rep: rep, method: method}, nil
440 }
441
442 type repClientStream struct {
443 ctx context.Context
444 rep *Replayer
445 method string
446 str *stream
447 }
448
449 func (rcs *repClientStream) Context() context.Context { return rcs.ctx }
450
451 func (rcs *repClientStream) SendMsg(req interface{}) error {
452 if rcs.str == nil {
453 if err := rcs.setStream(rcs.method, req.(proto.Message)); err != nil {
454 return err
455 }
456 }
457 if len(rcs.str.sends) == 0 {
458 return fmt.Errorf("replayer: no more sends for stream %s, created at index %d",
459 rcs.str.method, rcs.str.createIndex)
460 }
461
462 msg := rcs.str.sends[0]
463 rcs.str.sends = rcs.str.sends[1:]
464 return msg.err
465 }
466
467 func (rcs *repClientStream) setStream(method string, req proto.Message) error {
468 str := rcs.rep.extractStream(method, req)
469 if str == nil {
470 return fmt.Errorf("replayer: stream not found for method %s and request %v", method, req)
471 }
472 if str.createErr != nil {
473 return str.createErr
474 }
475 rcs.str = str
476 return nil
477 }
478
479 func (rcs *repClientStream) RecvMsg(m interface{}) error {
480 if rcs.str == nil {
481
482 if err := rcs.setStream(rcs.method, nil); err != nil {
483 return err
484 }
485 }
486 if len(rcs.str.recvs) == 0 {
487 return fmt.Errorf("replayer: no more receives for stream %s, created at index %d",
488 rcs.str.method, rcs.str.createIndex)
489 }
490 msg := rcs.str.recvs[0]
491 rcs.str.recvs = rcs.str.recvs[1:]
492 if msg.err != nil {
493 return msg.err
494 }
495 proto.Merge(m.(proto.Message), msg.msg)
496 return nil
497 }
498
499 func (rcs *repClientStream) Header() (metadata.MD, error) {
500 log.Printf("replay: stream metadata not supported")
501 return nil, nil
502 }
503
504 func (rcs *repClientStream) Trailer() metadata.MD {
505 log.Printf("replay: stream metadata not supported")
506 return nil
507 }
508
509 func (rcs *repClientStream) CloseSend() error {
510 return nil
511 }
512
513
514
515 func (rep *Replayer) extractCall(method string, req proto.Message) *call {
516 rep.mu.Lock()
517 defer rep.mu.Unlock()
518 for i, call := range rep.calls {
519 if call == nil {
520 continue
521 }
522 if method == call.method && proto.Equal(req, call.request) {
523 rep.calls[i] = nil
524 return call
525 }
526 }
527 return nil
528 }
529
530
531
532
533 func (rep *Replayer) extractStream(method string, req proto.Message) *stream {
534 rep.mu.Lock()
535 defer rep.mu.Unlock()
536 for i, stream := range rep.streams {
537
538 if stream == nil || stream.method != method {
539 continue
540 }
541
542
543 if req != nil && len(stream.sends) > 0 && !proto.Equal(req, stream.sends[0].msg) {
544 continue
545 }
546 rep.streams[i] = nil
547 return stream
548 }
549 return nil
550 }
551
552
553
554 func Fprint(w io.Writer, filename string) error {
555 f, err := os.Open(filename)
556 if err != nil {
557 return err
558 }
559 defer f.Close()
560 return FprintReader(w, f)
561 }
562
563
564
565 func FprintReader(w io.Writer, r io.Reader) error {
566 initial, err := readHeader(r)
567 if err != nil {
568 return err
569 }
570 fmt.Fprintf(w, "initial state: %q\n", string(initial))
571 for i := 1; ; i++ {
572 e, err := readEntry(r)
573 if err != nil {
574 return err
575 }
576 if e == nil {
577 return nil
578 }
579
580 fmt.Fprintf(w, "#%d: kind: %s, method: %s, ref index: %d", i, e.kind, e.method, e.refIndex)
581 switch {
582 case e.msg.msg != nil:
583 fmt.Fprintf(w, ", message:\n")
584 b, err := prototext.Marshal(e.msg.msg)
585 if err != nil {
586 return err
587 }
588 if _, err := io.Copy(w, bytes.NewReader(b)); err != nil {
589 return err
590 }
591 case e.msg.err != nil:
592 fmt.Fprintf(w, ", error: %v\n", e.msg.err)
593 default:
594 fmt.Fprintln(w)
595 }
596 }
597 }
598
599
600 type entry struct {
601 kind pb.Entry_Kind
602 method string
603 msg message
604 refIndex int
605 }
606
607 func (e1 *entry) equal(e2 *entry) bool {
608 if e1 == nil && e2 == nil {
609 return true
610 }
611 if e1 == nil || e2 == nil {
612 return false
613 }
614 return e1.kind == e2.kind &&
615 e1.method == e2.method &&
616 proto.Equal(e1.msg.msg, e2.msg.msg) &&
617 errEqual(e1.msg.err, e2.msg.err) &&
618 e1.refIndex == e2.refIndex
619 }
620
621 func errEqual(e1, e2 error) bool {
622 if e1 == e2 {
623 return true
624 }
625 s1, ok1 := status.FromError(e1)
626 s2, ok2 := status.FromError(e2)
627 if !ok1 || !ok2 {
628 return false
629 }
630 return proto.Equal(s1.Proto(), s2.Proto())
631 }
632
633
634 type message struct {
635 msg proto.Message
636 err error
637 }
638
639 func (m *message) set(msg interface{}, err error) {
640 m.err = err
641 if err != io.EOF && msg != nil {
642 m.msg = msg.(proto.Message)
643 }
644 }
645
646
647
648
649
650
651
652
653
654 const magic = "RPCReplay"
655
656 func writeHeader(w io.Writer, initial []byte) error {
657 if _, err := io.WriteString(w, magic); err != nil {
658 return err
659 }
660 return writeRecord(w, initial)
661 }
662
663 func readHeader(r io.Reader) ([]byte, error) {
664 var buf [len(magic)]byte
665 if _, err := io.ReadFull(r, buf[:]); err != nil {
666 if err == io.EOF {
667 err = errors.New("rpcreplay: empty replay file")
668 }
669 return nil, err
670 }
671 if string(buf[:]) != magic {
672 return nil, errors.New("rpcreplay: not a replay file (does not begin with magic string)")
673 }
674 bytes, err := readRecord(r)
675 if err == io.EOF {
676 err = errors.New("rpcreplay: missing initial state")
677 }
678 return bytes, err
679 }
680
681 func writeEntry(w io.Writer, e *entry) error {
682 var m proto.Message
683 if e.msg.err != nil && e.msg.err != io.EOF {
684 s, ok := status.FromError(e.msg.err)
685 if !ok {
686 return fmt.Errorf("rpcreplay: error %w is not a Status", e.msg.err)
687 }
688 m = s.Proto()
689 } else {
690 m = e.msg.msg
691 }
692 var a *anypb.Any
693 var err error
694 if m != nil {
695 a, err = anypb.New(m)
696 if err != nil {
697 return err
698 }
699 }
700 pe := &pb.Entry{
701 Kind: e.kind,
702 Method: e.method,
703 Message: a,
704 IsError: e.msg.err != nil,
705 RefIndex: int32(e.refIndex),
706 }
707 bytes, err := proto.Marshal(pe)
708 if err != nil {
709 return err
710 }
711 return writeRecord(w, bytes)
712 }
713
714 func readEntry(r io.Reader) (*entry, error) {
715 buf, err := readRecord(r)
716 if err == io.EOF {
717 return nil, nil
718 }
719 if err != nil {
720 return nil, err
721 }
722 var pe pb.Entry
723 if err := proto.Unmarshal(buf, &pe); err != nil {
724 return nil, err
725 }
726 var msg message
727 if pe.Message != nil {
728 if pe.IsError {
729 s := &spb.Status{}
730 err := anypb.UnmarshalTo(pe.Message, s, proto.UnmarshalOptions{AllowPartial: true, DiscardUnknown: true})
731 if err != nil {
732 return nil, err
733 }
734 msg.err = status.ErrorProto(s)
735 } else {
736 m, err := anypb.UnmarshalNew(pe.Message, proto.UnmarshalOptions{AllowPartial: true, DiscardUnknown: true})
737 if err != nil {
738 return nil, err
739 }
740 msg.msg = m
741 }
742 } else if pe.IsError {
743 msg.err = io.EOF
744 } else if pe.Kind != pb.Entry_CREATE_STREAM {
745 return nil, errors.New("rpcreplay: entry with nil message and false is_error")
746 }
747 return &entry{
748 kind: pe.Kind,
749 method: pe.Method,
750 msg: msg,
751 refIndex: int(pe.RefIndex),
752 }, nil
753 }
754
755
756
757
// writeRecord writes data to w as a record: its length as a little-endian
// uint32, followed by the bytes themselves.
func writeRecord(w io.Writer, data []byte) error {
	var header [4]byte
	binary.LittleEndian.PutUint32(header[:], uint32(len(data)))
	if _, err := w.Write(header[:]); err != nil {
		return err
	}
	_, err := w.Write(data)
	return err
}
765
// readRecord reads one record from r: a little-endian uint32 length
// followed by that many bytes.
//
// It returns io.EOF only when r is exhausted before the record starts, which
// is a clean end of input. Once a length prefix has been read, a missing or
// short body is reported as io.ErrUnexpectedEOF. Without this, callers that
// treat io.EOF as "no more records" would silently accept a truncated file.
func readRecord(r io.Reader) ([]byte, error) {
	var size uint32
	if err := binary.Read(r, binary.LittleEndian, &size); err != nil {
		return nil, err
	}
	buf := make([]byte, size)
	if _, err := io.ReadFull(r, buf); err != nil {
		// io.ReadFull returns io.EOF when no body bytes at all were available.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return buf, nil
}
777
View as plain text