...

Source file src/cloud.google.com/go/rpcreplay/rpcreplay_test.go

Documentation: cloud.google.com/go/rpcreplay

     1  // Copyright 2017 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rpcreplay
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"errors"
    21  	"io"
    22  	"strings"
    23  	"testing"
    24  
    25  	"cloud.google.com/go/internal/testutil"
    26  	ipb "cloud.google.com/go/rpcreplay/proto/intstore"
    27  	rpb "cloud.google.com/go/rpcreplay/proto/rpcreplay"
    28  	"github.com/google/go-cmp/cmp"
    29  	"github.com/google/go-cmp/cmp/cmpopts"
    30  	"google.golang.org/grpc"
    31  	"google.golang.org/grpc/codes"
    32  	"google.golang.org/grpc/status"
    33  	"google.golang.org/protobuf/proto"
    34  	"google.golang.org/protobuf/testing/protocmp"
    35  )
    36  
    37  func TestRecordIO(t *testing.T) {
    38  	buf := &bytes.Buffer{}
    39  	want := []byte{1, 2, 3}
    40  	if err := writeRecord(buf, want); err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	got, err := readRecord(buf)
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	if !bytes.Equal(got, want) {
    48  		t.Errorf("got %v, want %v", got, want)
    49  	}
    50  }
    51  
    52  func TestHeaderIO(t *testing.T) {
    53  	buf := &bytes.Buffer{}
    54  	want := []byte{1, 2, 3}
    55  	if err := writeHeader(buf, want); err != nil {
    56  		t.Fatal(err)
    57  	}
    58  	got, err := readHeader(buf)
    59  	if err != nil {
    60  		t.Fatal(err)
    61  	}
    62  	if !testutil.Equal(got, want) {
    63  		t.Errorf("got %v, want %v", got, want)
    64  	}
    65  
    66  	// readHeader errors
    67  	for _, contents := range []string{"", "badmagic", "gRPCReplay"} {
    68  		if _, err := readHeader(bytes.NewBufferString(contents)); err == nil {
    69  			t.Errorf("%q: got nil, want error", contents)
    70  		}
    71  	}
    72  }
    73  
    74  func TestEntryIO(t *testing.T) {
    75  	for i, want := range []*entry{
    76  		{
    77  			kind:     rpb.Entry_REQUEST,
    78  			method:   "method",
    79  			msg:      message{msg: &rpb.Entry{}},
    80  			refIndex: 7,
    81  		},
    82  		{
    83  			kind:     rpb.Entry_RESPONSE,
    84  			method:   "method",
    85  			msg:      message{err: status.Error(codes.NotFound, "not found")},
    86  			refIndex: 8,
    87  		},
    88  		{
    89  			kind:     rpb.Entry_RECV,
    90  			method:   "method",
    91  			msg:      message{err: io.EOF},
    92  			refIndex: 3,
    93  		},
    94  	} {
    95  		buf := &bytes.Buffer{}
    96  		if err := writeEntry(buf, want); err != nil {
    97  			t.Fatal(err)
    98  		}
    99  		got, err := readEntry(buf)
   100  		if err != nil {
   101  			t.Fatal(err)
   102  		}
   103  		if !got.equal(want) {
   104  			t.Errorf("#%d: got %v, want %v", i, got, want)
   105  		}
   106  	}
   107  }
   108  
   109  var initialState = []byte{1, 2, 3} // initial-state bytes that record passes to NewRecorderWriter; tests expect them back from readHeader and rep.Initial()
   110  
   111  func TestRecord(t *testing.T) {
   112  	buf := record(t, testService)
   113  
   114  	gotIstate, err := readHeader(buf)
   115  	if err != nil {
   116  		t.Fatal(err)
   117  	}
   118  	if !testutil.Equal(gotIstate, initialState) {
   119  		t.Fatalf("got %v, want %v", gotIstate, initialState)
   120  	}
   121  	item := &ipb.Item{Name: "a", Value: 1}
   122  	wantEntries := []*entry{
   123  		// Set
   124  		{
   125  			kind:   rpb.Entry_REQUEST,
   126  			method: "/intstore.IntStore/Set",
   127  			msg:    message{msg: item},
   128  		},
   129  		{
   130  			kind:     rpb.Entry_RESPONSE,
   131  			msg:      message{msg: &ipb.SetResponse{PrevValue: 0}},
   132  			refIndex: 1,
   133  		},
   134  		// Get
   135  		{
   136  			kind:   rpb.Entry_REQUEST,
   137  			method: "/intstore.IntStore/Get",
   138  			msg:    message{msg: &ipb.GetRequest{Name: "a"}},
   139  		},
   140  		{
   141  			kind:     rpb.Entry_RESPONSE,
   142  			msg:      message{msg: item},
   143  			refIndex: 3,
   144  		},
   145  		{
   146  			kind:   rpb.Entry_REQUEST,
   147  			method: "/intstore.IntStore/Get",
   148  			msg:    message{msg: &ipb.GetRequest{Name: "x"}},
   149  		},
   150  		{
   151  			kind:     rpb.Entry_RESPONSE,
   152  			msg:      message{err: status.Error(codes.NotFound, `"x"`)},
   153  			refIndex: 5,
   154  		},
   155  		// ListItems
   156  		{ // entry #7
   157  			kind:   rpb.Entry_CREATE_STREAM,
   158  			method: "/intstore.IntStore/ListItems",
   159  		},
   160  		{
   161  			kind:     rpb.Entry_SEND,
   162  			msg:      message{msg: &ipb.ListItemsRequest{}},
   163  			refIndex: 7,
   164  		},
   165  		{
   166  			kind:     rpb.Entry_RECV,
   167  			msg:      message{msg: item},
   168  			refIndex: 7,
   169  		},
   170  		{
   171  			kind:     rpb.Entry_RECV,
   172  			msg:      message{err: io.EOF},
   173  			refIndex: 7,
   174  		},
   175  		// SetStream
   176  		{ // entry #11
   177  			kind:   rpb.Entry_CREATE_STREAM,
   178  			method: "/intstore.IntStore/SetStream",
   179  		},
   180  		{
   181  			kind:     rpb.Entry_SEND,
   182  			msg:      message{msg: &ipb.Item{Name: "b", Value: 2}},
   183  			refIndex: 11,
   184  		},
   185  		{
   186  			kind:     rpb.Entry_SEND,
   187  			msg:      message{msg: &ipb.Item{Name: "c", Value: 3}},
   188  			refIndex: 11,
   189  		},
   190  		{
   191  			kind:     rpb.Entry_RECV,
   192  			msg:      message{msg: &ipb.Summary{Count: 2}},
   193  			refIndex: 11,
   194  		},
   195  
   196  		// StreamChat
   197  		{ // entry #15
   198  			kind:   rpb.Entry_CREATE_STREAM,
   199  			method: "/intstore.IntStore/StreamChat",
   200  		},
   201  		{
   202  			kind:     rpb.Entry_SEND,
   203  			msg:      message{msg: &ipb.Item{Name: "d", Value: 4}},
   204  			refIndex: 15,
   205  		},
   206  		{
   207  			kind:     rpb.Entry_RECV,
   208  			msg:      message{msg: &ipb.Item{Name: "d", Value: 4}},
   209  			refIndex: 15,
   210  		},
   211  		{
   212  			kind:     rpb.Entry_SEND,
   213  			msg:      message{msg: &ipb.Item{Name: "e", Value: 5}},
   214  			refIndex: 15,
   215  		},
   216  		{
   217  			kind:     rpb.Entry_RECV,
   218  			msg:      message{msg: &ipb.Item{Name: "e", Value: 5}},
   219  			refIndex: 15,
   220  		},
   221  		{
   222  			kind:     rpb.Entry_RECV,
   223  			msg:      message{err: io.EOF},
   224  			refIndex: 15,
   225  		},
   226  	}
   227  	for i, w := range wantEntries {
   228  		g, err := readEntry(buf)
   229  		if err != nil {
   230  			t.Fatalf("#%d: %v", i+1, err)
   231  		}
   232  		if !g.equal(w) {
   233  			t.Errorf("#%d:\ngot  %+v\nwant %+v", i+1, g, w)
   234  		}
   235  	}
   236  	g, err := readEntry(buf)
   237  	if err != nil {
   238  		t.Fatal(err)
   239  	}
   240  	if g != nil {
   241  		t.Errorf("\ngot  %+v\nwant nil", g)
   242  	}
   243  }
   244  
   245  func TestReplay(t *testing.T) {
   246  	buf := record(t, testService)
   247  	replay(t, buf, testService)
   248  }
   249  
   250  func record(t *testing.T, run func(*testing.T, *grpc.ClientConn)) *bytes.Buffer {
   251  	srv := newIntStoreServer()
   252  	defer srv.stop()
   253  
   254  	buf := &bytes.Buffer{}
   255  	rec, err := NewRecorderWriter(buf, initialState)
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  	conn, err := grpc.Dial(srv.Addr,
   260  		append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...)
   261  	if err != nil {
   262  		t.Fatal(err)
   263  	}
   264  	defer conn.Close()
   265  	run(t, conn)
   266  	if err := rec.Close(); err != nil {
   267  		t.Fatal(err)
   268  	}
   269  	return buf
   270  }
   271  
   272  func replay(t *testing.T, buf *bytes.Buffer, run func(*testing.T, *grpc.ClientConn)) {
   273  	rep, err := NewReplayerReader(buf)
   274  	if err != nil {
   275  		t.Fatal(err)
   276  	}
   277  	defer rep.Close()
   278  	if got, want := rep.Initial(), initialState; !testutil.Equal(got, want) {
   279  		t.Fatalf("got %v, want %v", got, want)
   280  	}
   281  	// Replay the test.
   282  	conn, err := rep.Connection()
   283  	if err != nil {
   284  		t.Fatal(err)
   285  	}
   286  	defer conn.Close()
   287  	run(t, conn)
   288  }
   289  
   290  func testService(t *testing.T, conn *grpc.ClientConn) {
   291  	client := ipb.NewIntStoreClient(conn)
   292  	ctx := context.Background()
   293  	item := &ipb.Item{Name: "a", Value: 1}
   294  	res, err := client.Set(ctx, item)
   295  	if err != nil {
   296  		t.Fatal(err)
   297  	}
   298  	if res.PrevValue != 0 {
   299  		t.Errorf("got %d, want 0", res.PrevValue)
   300  	}
   301  	got, err := client.Get(ctx, &ipb.GetRequest{Name: "a"})
   302  	if err != nil {
   303  		t.Fatal(err)
   304  	}
   305  	if !proto.Equal(got, item) {
   306  		t.Errorf("got %v, want %v", got, item)
   307  	}
   308  	_, err = client.Get(ctx, &ipb.GetRequest{Name: "x"})
   309  	if err == nil {
   310  		t.Fatal("got nil, want error")
   311  	}
   312  	if _, ok := status.FromError(err); !ok {
   313  		t.Errorf("got error type %T, want a grpc/status.Status", err)
   314  	}
   315  
   316  	gotItems := listItems(t, client, 0)
   317  	compareLists(t, gotItems, []*ipb.Item{item})
   318  
   319  	ssc, err := client.SetStream(ctx)
   320  	if err != nil {
   321  		t.Fatal(err)
   322  	}
   323  
   324  	must := func(err error) {
   325  		if err != nil {
   326  			t.Fatal(err)
   327  		}
   328  	}
   329  
   330  	for i, name := range []string{"b", "c"} {
   331  		must(ssc.Send(&ipb.Item{Name: name, Value: int32(i + 2)}))
   332  	}
   333  	summary, err := ssc.CloseAndRecv()
   334  	if err != nil {
   335  		t.Fatal(err)
   336  	}
   337  	if got, want := summary.Count, int32(2); got != want {
   338  		t.Fatalf("got %d, want %d", got, want)
   339  	}
   340  
   341  	chatc, err := client.StreamChat(ctx)
   342  	if err != nil {
   343  		t.Fatal(err)
   344  	}
   345  	for i, name := range []string{"d", "e"} {
   346  		item := &ipb.Item{Name: name, Value: int32(i + 4)}
   347  		must(chatc.Send(item))
   348  		got, err := chatc.Recv()
   349  		if err != nil {
   350  			t.Fatal(err)
   351  		}
   352  		if !proto.Equal(got, item) {
   353  			t.Errorf("got %v, want %v", got, item)
   354  		}
   355  	}
   356  	must(chatc.CloseSend())
   357  	if _, err := chatc.Recv(); err != io.EOF {
   358  		t.Fatalf("got %v, want EOF", err)
   359  	}
   360  }
   361  
   362  func listItems(t *testing.T, client ipb.IntStoreClient, greaterThan int) []*ipb.Item {
   363  	t.Helper()
   364  	lic, err := client.ListItems(context.Background(), &ipb.ListItemsRequest{GreaterThan: int32(greaterThan)})
   365  	if err != nil {
   366  		t.Fatal(err)
   367  	}
   368  	var items []*ipb.Item
   369  	for i := 0; ; i++ {
   370  		item, err := lic.Recv()
   371  		if err == io.EOF {
   372  			break
   373  		}
   374  		if err != nil {
   375  			t.Fatal(err)
   376  		}
   377  		items = append(items, item)
   378  	}
   379  	return items
   380  }
   381  
   382  func compareLists(t *testing.T, got, want []*ipb.Item) {
   383  	t.Helper()
   384  	diff := cmp.Diff(got, want, cmp.Comparer(proto.Equal), cmpopts.SortSlices(func(i1, i2 *ipb.Item) bool {
   385  		return i1.Value < i2.Value
   386  	}))
   387  	if diff != "" {
   388  		t.Error(diff)
   389  	}
   390  }
   391  
   392  func TestRecorderBeforeFunc(t *testing.T) {
   393  	var tests = []struct {
   394  		name                           string
   395  		msg, wantRespMsg, wantEntryMsg *ipb.Item
   396  		f                              func(string, proto.Message) error
   397  		wantErr                        bool
   398  	}{
   399  		{
   400  			name:         "BeforeFunc should modify messages saved, but not alter what is sent/received to/from services",
   401  			msg:          &ipb.Item{Name: "foo", Value: 1},
   402  			wantEntryMsg: &ipb.Item{Name: "bar", Value: 2},
   403  			wantRespMsg:  &ipb.Item{Name: "foo", Value: 1},
   404  			f: func(method string, m proto.Message) error {
   405  				// This callback only runs when Set is called.
   406  				if !strings.HasSuffix(method, "Set") {
   407  					return nil
   408  				}
   409  				if _, ok := m.(*ipb.Item); !ok {
   410  					return nil
   411  				}
   412  
   413  				item := m.(*ipb.Item)
   414  				item.Name = "bar"
   415  				item.Value = 2
   416  				return nil
   417  			},
   418  		},
   419  		{
   420  			name:        "BeforeFunc should not be able to alter returned responses",
   421  			msg:         &ipb.Item{Name: "foo", Value: 1},
   422  			wantRespMsg: &ipb.Item{Name: "foo", Value: 1},
   423  			f: func(method string, m proto.Message) error {
   424  				// This callback only runs when Get is called.
   425  				if !strings.HasSuffix(method, "Get") {
   426  					return nil
   427  				}
   428  				if _, ok := m.(*ipb.Item); !ok {
   429  					return nil
   430  				}
   431  
   432  				item := m.(*ipb.Item)
   433  				item.Value = 2
   434  				return nil
   435  			},
   436  		},
   437  		{
   438  			name: "Errors should cause the RPC send to fail",
   439  			msg:  &ipb.Item{},
   440  			f: func(_ string, _ proto.Message) error {
   441  				return errors.New("err")
   442  			},
   443  			wantErr: true,
   444  		},
   445  	}
   446  
   447  	for _, tc := range tests {
   448  		// Wrap test cases in a func so defers execute correctly.
   449  		func() {
   450  			srv := newIntStoreServer()
   451  			defer srv.stop()
   452  
   453  			var b bytes.Buffer
   454  			r, err := NewRecorderWriter(&b, nil)
   455  			if err != nil {
   456  				t.Error(err)
   457  				return
   458  			}
   459  			r.BeforeFunc = tc.f
   460  			ctx := context.Background()
   461  			conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, r.DialOptions()...)...)
   462  			if err != nil {
   463  				t.Error(err)
   464  				return
   465  			}
   466  			defer conn.Close()
   467  
   468  			client := ipb.NewIntStoreClient(conn)
   469  			_, err = client.Set(ctx, tc.msg)
   470  			switch {
   471  			case err != nil && !tc.wantErr:
   472  				t.Error(err)
   473  				return
   474  			case err == nil && tc.wantErr:
   475  				t.Errorf("got nil; want error")
   476  				return
   477  			case err != nil:
   478  				// Error found as expected, don't check Get().
   479  				return
   480  			}
   481  
   482  			if tc.wantRespMsg != nil {
   483  				got, err := client.Get(ctx, &ipb.GetRequest{Name: tc.msg.GetName()})
   484  				if err != nil {
   485  					t.Error(err)
   486  					return
   487  				}
   488  				if !cmp.Equal(got, tc.wantRespMsg, protocmp.Transform()) {
   489  					t.Errorf("got %+v; want %+v", got, tc.wantRespMsg)
   490  				}
   491  			}
   492  
   493  			r.Close()
   494  
   495  			if tc.wantEntryMsg != nil {
   496  				_, _ = readHeader(&b)
   497  				e, err := readEntry(&b)
   498  				if err != nil {
   499  					t.Error(err)
   500  					return
   501  				}
   502  				got := e.msg.msg.(*ipb.Item)
   503  				if !cmp.Equal(got, tc.wantEntryMsg, protocmp.Transform()) {
   504  					t.Errorf("got %v; want %v", got, tc.wantEntryMsg)
   505  				}
   506  			}
   507  		}()
   508  	}
   509  }
   510  
   511  func TestReplayerBeforeFunc(t *testing.T) {
   512  	var tests = []struct {
   513  		name        string
   514  		msg, reqMsg *ipb.Item
   515  		f           func(string, proto.Message) error
   516  		wantErr     bool
   517  	}{
   518  		{
   519  			name:   "BeforeFunc should modify messages sent before they are passed to the replayer",
   520  			msg:    &ipb.Item{Name: "foo", Value: 1},
   521  			reqMsg: &ipb.Item{Name: "bar", Value: 1},
   522  			f: func(method string, m proto.Message) error {
   523  				item := m.(*ipb.Item)
   524  				item.Name = "foo"
   525  				return nil
   526  			},
   527  		},
   528  		{
   529  			name: "Errors should cause the RPC send to fail",
   530  			msg:  &ipb.Item{},
   531  			f: func(_ string, _ proto.Message) error {
   532  				return errors.New("err")
   533  			},
   534  			wantErr: true,
   535  		},
   536  	}
   537  
   538  	for _, tc := range tests {
   539  		// Wrap test cases in a func so defers execute correctly.
   540  		func() {
   541  			srv := newIntStoreServer()
   542  			defer srv.stop()
   543  
   544  			var b bytes.Buffer
   545  			rec, err := NewRecorderWriter(&b, nil)
   546  			if err != nil {
   547  				t.Error(err)
   548  				return
   549  			}
   550  			ctx := context.Background()
   551  			conn, err := grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rec.DialOptions()...)...)
   552  			if err != nil {
   553  				t.Error(err)
   554  				return
   555  			}
   556  			defer conn.Close()
   557  
   558  			client := ipb.NewIntStoreClient(conn)
   559  			_, err = client.Set(ctx, tc.msg)
   560  			if err != nil {
   561  				t.Error(err)
   562  				return
   563  			}
   564  			rec.Close()
   565  
   566  			rep, err := NewReplayerReader(&b)
   567  			if err != nil {
   568  				t.Error(err)
   569  				return
   570  			}
   571  			rep.BeforeFunc = tc.f
   572  			conn, err = grpc.DialContext(ctx, srv.Addr, append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...)
   573  			if err != nil {
   574  				t.Error(err)
   575  				return
   576  			}
   577  			defer conn.Close()
   578  
   579  			client = ipb.NewIntStoreClient(conn)
   580  			_, err = client.Set(ctx, tc.reqMsg)
   581  			switch {
   582  			case err != nil && !tc.wantErr:
   583  				t.Error(err)
   584  			case err == nil && tc.wantErr:
   585  				t.Errorf("got nil; want error")
   586  			}
   587  		}()
   588  	}
   589  }
   590  
   591  func TestOutOfOrderStreamReplay(t *testing.T) {
   592  	// Check that streams are matched by method and first request sent, if any.
   593  
   594  	items := []*ipb.Item{
   595  		{Name: "a", Value: 1},
   596  		{Name: "b", Value: 2},
   597  		{Name: "c", Value: 3},
   598  	}
   599  	run := func(t *testing.T, conn *grpc.ClientConn, arg1, arg2 int) {
   600  		client := ipb.NewIntStoreClient(conn)
   601  		ctx := context.Background()
   602  		// Set some items.
   603  		for _, item := range items {
   604  			_, err := client.Set(ctx, item)
   605  			if err != nil {
   606  				t.Fatal(err)
   607  			}
   608  		}
   609  		// List them twice, with different requests.
   610  		compareLists(t, listItems(t, client, arg1), items[arg1:])
   611  		compareLists(t, listItems(t, client, arg2), items[arg2:])
   612  	}
   613  
   614  	srv := newIntStoreServer()
   615  	defer srv.stop()
   616  
   617  	// Replay in the same order.
   618  	buf := record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
   619  	replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
   620  
   621  	// Replay in a different order.
   622  	buf = record(t, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 1, 2) })
   623  	replay(t, buf, func(t *testing.T, conn *grpc.ClientConn) { run(t, conn, 2, 1) })
   624  }
   625  

View as plain text