1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package rpcreplay
16
17 import (
18 "bytes"
19 "context"
20 "errors"
21 "io"
22 "strings"
23 "testing"
24
25 "cloud.google.com/go/internal/testutil"
26 ipb "cloud.google.com/go/rpcreplay/proto/intstore"
27 rpb "cloud.google.com/go/rpcreplay/proto/rpcreplay"
28 "github.com/google/go-cmp/cmp"
29 "github.com/google/go-cmp/cmp/cmpopts"
30 "google.golang.org/grpc"
31 "google.golang.org/grpc/codes"
32 "google.golang.org/grpc/status"
33 "google.golang.org/protobuf/proto"
34 "google.golang.org/protobuf/testing/protocmp"
35 )
36
37 func TestRecordIO(t *testing.T) {
38 buf := &bytes.Buffer{}
39 want := []byte{1, 2, 3}
40 if err := writeRecord(buf, want); err != nil {
41 t.Fatal(err)
42 }
43 got, err := readRecord(buf)
44 if err != nil {
45 t.Fatal(err)
46 }
47 if !bytes.Equal(got, want) {
48 t.Errorf("got %v, want %v", got, want)
49 }
50 }
51
52 func TestHeaderIO(t *testing.T) {
53 buf := &bytes.Buffer{}
54 want := []byte{1, 2, 3}
55 if err := writeHeader(buf, want); err != nil {
56 t.Fatal(err)
57 }
58 got, err := readHeader(buf)
59 if err != nil {
60 t.Fatal(err)
61 }
62 if !testutil.Equal(got, want) {
63 t.Errorf("got %v, want %v", got, want)
64 }
65
66
67 for _, contents := range []string{"", "badmagic", "gRPCReplay"} {
68 if _, err := readHeader(bytes.NewBufferString(contents)); err == nil {
69 t.Errorf("%q: got nil, want error", contents)
70 }
71 }
72 }
73
74 func TestEntryIO(t *testing.T) {
75 for i, want := range []*entry{
76 {
77 kind: rpb.Entry_REQUEST,
78 method: "method",
79 msg: message{msg: &rpb.Entry{}},
80 refIndex: 7,
81 },
82 {
83 kind: rpb.Entry_RESPONSE,
84 method: "method",
85 msg: message{err: status.Error(codes.NotFound, "not found")},
86 refIndex: 8,
87 },
88 {
89 kind: rpb.Entry_RECV,
90 method: "method",
91 msg: message{err: io.EOF},
92 refIndex: 3,
93 },
94 } {
95 buf := &bytes.Buffer{}
96 if err := writeEntry(buf, want); err != nil {
97 t.Fatal(err)
98 }
99 got, err := readEntry(buf)
100 if err != nil {
101 t.Fatal(err)
102 }
103 if !got.equal(want) {
104 t.Errorf("#%d: got %v, want %v", i, got, want)
105 }
106 }
107 }
108
// initialState is the opaque initial-state payload. record passes it to
// NewRecorderWriter. TestRecord expects to read it back from the log header,
// and replay expects Replayer.Initial to return it.
var initialState = []byte{0x01, 0x02, 0x03}
110
111 func TestRecord(t *testing.T) {
112 buf := record(t, testService)
113
114 gotIstate, err := readHeader(buf)
115 if err != nil {
116 t.Fatal(err)
117 }
118 if !testutil.Equal(gotIstate, initialState) {
119 t.Fatalf("got %v, want %v", gotIstate, initialState)
120 }
121 item := &ipb.Item{Name: "a", Value: 1}
122 wantEntries := []*entry{
123
124 {
125 kind: rpb.Entry_REQUEST,
126 method: "/intstore.IntStore/Set",
127 msg: message{msg: item},
128 },
129 {
130 kind: rpb.Entry_RESPONSE,
131 msg: message{msg: &ipb.SetResponse{PrevValue: 0}},
132 refIndex: 1,
133 },
134
135 {
136 kind: rpb.Entry_REQUEST,
137 method: "/intstore.IntStore/Get",
138 msg: message{msg: &ipb.GetRequest{Name: "a"}},
139 },
140 {
141 kind: rpb.Entry_RESPONSE,
142 msg: message{msg: item},
143 refIndex: 3,
144 },
145 {
146 kind: rpb.Entry_REQUEST,
147 method: "/intstore.IntStore/Get",
148 msg: message{msg: &ipb.GetRequest{Name: "x"}},
149 },
150 {
151 kind: rpb.Entry_RESPONSE,
152 msg: message{err: status.Error(codes.NotFound, `"x"`)},
153 refIndex: 5,
154 },
155
156 {
157 kind: rpb.Entry_CREATE_STREAM,
158 method: "/intstore.IntStore/ListItems",
159 },
160 {
161 kind: rpb.Entry_SEND,
162 msg: message{msg: &ipb.ListItemsRequest{}},
163 refIndex: 7,
164 },
165 {
166 kind: rpb.Entry_RECV,
167 msg: message{msg: item},
168 refIndex: 7,
169 },
170 {
171 kind: rpb.Entry_RECV,
172 msg: message{err: io.EOF},
173 refIndex: 7,
174 },
175
176 {
177 kind: rpb.Entry_CREATE_STREAM,
178 method: "/intstore.IntStore/SetStream",
179 },
180 {
181 kind: rpb.Entry_SEND,
182 msg: message{msg: &ipb.Item{Name: "b", Value: 2}},
183 refIndex: 11,
184 },
185 {
186 kind: rpb.Entry_SEND,
187 msg: message{msg: &ipb.Item{Name: "c", Value: 3}},
188 refIndex: 11,
189 },
190 {
191 kind: rpb.Entry_RECV,
192 msg: message{msg: &ipb.Summary{Count: 2}},
193 refIndex: 11,
194 },
195
196
197 {
198 kind: rpb.Entry_CREATE_STREAM,
199 method: "/intstore.IntStore/StreamChat",
200 },
201 {
202 kind: rpb.Entry_SEND,
203 msg: message{msg: &ipb.Item{Name: "d", Value: 4}},
204 refIndex: 15,
205 },
206 {
207 kind: rpb.Entry_RECV,
208 msg: message{msg: &ipb.Item{Name: "d", Value: 4}},
209 refIndex: 15,
210 },
211 {
212 kind: rpb.Entry_SEND,
213 msg: message{msg: &ipb.Item{Name: "e", Value: 5}},
214 refIndex: 15,
215 },
216 {
217 kind: rpb.Entry_RECV,
218 msg: message{msg: &ipb.Item{Name: "e", Value: 5}},
219 refIndex: 15,
220 },
221 {
222 kind: rpb.Entry_RECV,
223 msg: message{err: io.EOF},
224 refIndex: 15,
225 },
226 }
227 for i, w := range wantEntries {
228 g, err := readEntry(buf)
229 if err != nil {
230 t.Fatalf("#%d: %v", i+1, err)
231 }
232 if !g.equal(w) {
233 t.Errorf("#%d:\ngot %+v\nwant %+v", i+1, g, w)
234 }
235 }
236 g, err := readEntry(buf)
237 if err != nil {
238 t.Fatal(err)
239 }
240 if g != nil {
241 t.Errorf("\ngot %+v\nwant nil", g)
242 }
243 }
244
245 func TestReplay(t *testing.T) {
246 buf := record(t, testService)
247 replay(t, buf, testService)
248 }
249
250 func record(t *testing.T, run func(*testing.T, *grpc.ClientConn)) *bytes.Buffer {
251 srv := newIntStoreServer()
252 defer srv.stop()
253
254 buf := &bytes.Buffer{}
255 rec, err := NewRecorderWriter(buf, initialState)
256 if err != nil {
257 t.Fatal(err)
258 }
259 conn, err := grpc.Dial(srv.Addr,
260 append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...)
261 if err != nil {
262 t.Fatal(err)
263 }
264 defer conn.Close()
265 run(t, conn)
266 if err := rec.Close(); err != nil {
267 t.Fatal(err)
268 }
269 return buf
270 }
271
272 func replay(t *testing.T, buf *bytes.Buffer, run func(*testing.T, *grpc.ClientConn)) {
273 rep, err := NewReplayerReader(buf)
274 if err != nil {
275 t.Fatal(err)
276 }
277 defer rep.Close()
278 if got, want := rep.Initial(), initialState; !testutil.Equal(got, want) {
279 t.Fatalf("got %v, want %v", got, want)
280 }
281
282 conn, err := rep.Connection()
283 if err != nil {
284 t.Fatal(err)
285 }
286 defer conn.Close()
287 run(t, conn)
288 }
289
290 func testService(t *testing.T, conn *grpc.ClientConn) {
291 client := ipb.NewIntStoreClient(conn)
292 ctx := context.Background()
293 item := &ipb.Item{Name: "a", Value: 1}
294 res, err := client.Set(ctx, item)
295 if err != nil {
296 t.Fatal(err)
297 }
298 if res.PrevValue != 0 {
299 t.Errorf("got %d, want 0", res.PrevValue)
300 }
301 got, err := client.Get(ctx, &ipb.GetRequest{Name: "a"})
302 if err != nil {
303 t.Fatal(err)
304 }
305 if !proto.Equal(got, item) {
306 t.Errorf("got %v, want %v", got, item)
307 }
308 _, err = client.Get(ctx, &ipb.GetRequest{Name: "x"})
309 if err == nil {
310 t.Fatal("got nil, want error")
311 }
312 if _, ok := status.FromError(err); !ok {
313 t.Errorf("got error type %T, want a grpc/status.Status", err)
314 }
315
316 gotItems := listItems(t, client, 0)
317 compareLists(t, gotItems, []*ipb.Item{item})
318
319 ssc, err := client.SetStream(ctx)
320 if err != nil {
321 t.Fatal(err)
322 }
323
324 must := func(err error) {
325 if err != nil {
326 t.Fatal(err)
327 }
328 }
329
330 for i, name := range []string{"b", "c"} {
331 must(ssc.Send(&ipb.Item{Name: name, Value: int32(i + 2)}))
332 }
333 summary, err := ssc.CloseAndRecv()
334 if err != nil {
335 t.Fatal(err)
336 }
337 if got, want := summary.Count, int32(2); got != want {
338 t.Fatalf("got %d, want %d", got, want)
339 }
340
341 chatc, err := client.StreamChat(ctx)
342 if err != nil {
343 t.Fatal(err)
344 }
345 for i, name := range []string{"d", "e"} {
346 item := &ipb.Item{Name: name, Value: int32(i + 4)}
347 must(chatc.Send(item))
348 got, err := chatc.Recv()
349 if err != nil {
350 t.Fatal(err)
351 }
352 if !proto.Equal(got, item) {
353 t.Errorf("got %v, want %v", got, item)
354 }
355 }
356 must(chatc.CloseSend())
357 if _, err := chatc.Recv(); err != io.EOF {
358 t.Fatalf("got %v, want EOF", err)
359 }
360 }
361
362 func listItems(t *testing.T, client ipb.IntStoreClient, greaterThan int) []*ipb.Item {
363 t.Helper()
364 lic, err := client.ListItems(context.Background(), &ipb.ListItemsRequest{GreaterThan: int32(greaterThan)})
365 if err != nil {
366 t.Fatal(err)
367 }
368 var items []*ipb.Item
369 for i := 0; ; i++ {
370 item, err := lic.Recv()
371 if err == io.EOF {
372 break
373 }
374 if err != nil {
375 t.Fatal(err)
376 }
377 items = append(items, item)
378 }
379 return items
380 }
381
382 func compareLists(t *testing.T, got, want []*ipb.Item) {
383 t.Helper()
384 diff := cmp.Diff(got, want, cmp.Comparer(proto.Equal), cmpopts.SortSlices(func(i1, i2 *ipb.Item) bool {
385 return i1.Value < i2.Value
386 }))
387 if diff != "" {
388 t.Error(diff)
389 }
390 }
391
392 func TestRecorderBeforeFunc(t *testing.T) {
393 var tests = []struct {
394 name string
395 msg, wantRespMsg, wantEntryMsg *ipb.Item
396 f func(string, proto.Message) error
397 wantErr bool
398 }{
399 {
400 name: "BeforeFunc should modify messages saved, but not alter what is sent/received to/from services",
401 msg: &ipb.Item{Name: "foo", Value: 1},
402 wantEntryMsg: &ipb.Item{Name: "bar", Value: 2},
403 wantRespMsg: &ipb.Item{Name: "foo", Value: 1},
404 f: func(method string, m proto.Message) error {
405
406 if !strings.HasSuffix(method, "Set") {
407 return nil
408 }
409 if _, ok := m.(*ipb.Item); !ok {
410 return nil
411 }
412
413 item := m.(*ipb.Item)
414 item.Name = "bar"
415 item.Value = 2
416 return nil
417 },
418 },
419 {
420 name: "BeforeFunc should not be able to alter returned responses",
421 msg: &ipb.Item{Name: "foo", Value: 1},
422 wantRespMsg: &ipb.Item{Name: "foo", Value: 1},
423 f: func(method string, m proto.Message) error {
424
425 if !strings.HasSuffix(method, "Get") {
426 return nil
427 }
428 if _, ok := m.(*ipb.Item); !ok {
429 return nil
430 }
431
432 item := m.(*ipb.Item)
433 item.Value = 2
434 return nil
435 },
436 },
437 {
438 name: "Errors should cause the RPC send to fail",
439 msg: &ipb.Item{},
440 f: func(_ string, _ proto.Message) error {
441 return errors.New("err")
442 },
443 wantErr: true,
444 },
445 }
446
447 for _, tc := range tests {
448
449 func() {
450 srv := newIntStoreServer()
451 defer srv.stop()
452
453 var b bytes.Buffer
454 r, err := NewRecorderWriter(&b, nil)
455 if err != nil {
456 t.Error(err)
457 return
458 }
459 r.BeforeFunc = tc.f
460 ctx := context.Background()
461 conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, r.DialOptions()...)...)
462 if err != nil {
463 t.Error(err)
464 return
465 }
466 defer conn.Close()
467
468 client := ipb.NewIntStoreClient(conn)
469 _, err = client.Set(ctx, tc.msg)
470 switch {
471 case err != nil && !tc.wantErr:
472 t.Error(err)
473 return
474 case err == nil && tc.wantErr:
475 t.Errorf("got nil; want error")
476 return
477 case err != nil:
478
479 return
480 }
481
482 if tc.wantRespMsg != nil {
483 got, err := client.Get(ctx, &ipb.GetRequest{Name: tc.msg.GetName()})
484 if err != nil {
485 t.Error(err)
486 return
487 }
488 if !cmp.Equal(got, tc.wantRespMsg, protocmp.Transform()) {
489 t.Errorf("got %+v; want %+v", got, tc.wantRespMsg)
490 }
491 }
492
493 r.Close()
494
495 if tc.wantEntryMsg != nil {
496 _, _ = readHeader(&b)
497 e, err := readEntry(&b)
498 if err != nil {
499 t.Error(err)
500 return
501 }
502 got := e.msg.msg.(*ipb.Item)
503 if !cmp.Equal(got, tc.wantEntryMsg, protocmp.Transform()) {
504 t.Errorf("got %v; want %v", got, tc.wantEntryMsg)
505 }
506 }
507 }()
508 }
509 }
510
511 func TestReplayerBeforeFunc(t *testing.T) {
512 var tests = []struct {
513 name string
514 msg, reqMsg *ipb.Item
515 f func(string, proto.Message) error
516 wantErr bool
517 }{
518 {
519 name: "BeforeFunc should modify messages sent before they are passed to the replayer",
520 msg: &ipb.Item{Name: "foo", Value: 1},
521 reqMsg: &ipb.Item{Name: "bar", Value: 1},
522 f: func(method string, m proto.Message) error {
523 item := m.(*ipb.Item)
524 item.Name = "foo"
525 return nil
526 },
527 },
528 {
529 name: "Errors should cause the RPC send to fail",
530 msg: &ipb.Item{},
531 f: func(_ string, _ proto.Message) error {
532 return errors.New("err")
533 },
534 wantErr: true,
535 },
536 }
537
538 for _, tc := range tests {
539
540 func() {
541 srv := newIntStoreServer()
542 defer srv.stop()
543
544 var b bytes.Buffer
545 rec, err := NewRecorderWriter(&b, nil)
546 if err != nil {
547 t.Error(err)
548 return
549 }
550 ctx := context.Background()
551 conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...)
552 if err != nil {
553 t.Error(err)
554 return
555 }
556 defer conn.Close()
557
558 client := ipb.NewIntStoreClient(conn)
559 _, err = client.Set(ctx, tc.msg)
560 if err != nil {
561 t.Error(err)
562 return
563 }
564 rec.Close()
565
566 rep, err := NewReplayerReader(&b)
567 if err != nil {
568 t.Error(err)
569 return
570 }
571 rep.BeforeFunc = tc.f
572 conn, err = grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...)
573 if err != nil {
574 t.Error(err)
575 return
576 }
577 defer conn.Close()
578
579 client = ipb.NewIntStoreClient(conn)
580 _, err = client.Set(ctx, tc.reqMsg)
581 switch {
582 case err != nil && !tc.wantErr:
583 t.Error(err)
584 case err == nil && tc.wantErr:
585 t.Errorf("got nil; want error")
586 }
587 }()
588 }
589 }
590
591 func TestOutOfOrderStreamReplay(t *testing.T) {
592
593
594 items := []*ipb.Item{
595 {Name: "a", Value: 1},
596 {Name: "b", Value: 2},
597 {Name: "c", Value: 3},
598 }
599 run := func(t *testing.T, conn *grpc.ClientConn, arg1, arg2 int) {
600 client := ipb.NewIntStoreClient(conn)
601 ctx := context.Background()
602
603 for _, item := range items {
604 _, err := client.Set(ctx, item)
605 if err != nil {
606 t.Fatal(err)
607 }
608 }
609
610 compareLists(t, listItems(t, client, arg1), items[arg1:])
611 compareLists(t, listItems(t, client, arg2), items[arg2:])
612 }
613
614 srv := newIntStoreServer()
615 defer srv.stop()
616
617
618 buf := record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
619 replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
620
621
622 buf = record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
623 replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 2, 1) })
624 }
625
View as plain text