...

Source file src/github.com/containerd/ttrpc/server_test.go

Documentation: github.com/containerd/ttrpc

     1  /*
     2     Copyright The containerd Authors.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  package ttrpc
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"runtime"
    26  	"strings"
    27  	"sync"
    28  	"syscall"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/containerd/ttrpc/internal"
    33  	"google.golang.org/grpc/codes"
    34  	"google.golang.org/grpc/status"
    35  	"google.golang.org/protobuf/proto"
    36  )
    37  
    38  const serviceName = "testService"
    39  
    40  // testingService is our prototype service definition for use in testing the full model.
    41  //
    42  // Typically, this is generated. We define it here to ensure that that package
    43  // primitive has what is required for generated code.
    44  type testingService interface {
    45  	Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error)
    46  }
    47  
    48  type testingClient struct {
    49  	client *Client
    50  }
    51  
    52  func newTestingClient(client *Client) *testingClient {
    53  	return &testingClient{
    54  		client: client,
    55  	}
    56  }
    57  
    58  func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) {
    59  	var tp internal.TestPayload
    60  	return &tp, tc.client.Call(ctx, serviceName, "Test", req, &tp)
    61  }
    62  
    63  // testingServer is what would be implemented by the user of this package.
    64  type testingServer struct{}
    65  
    66  func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) {
    67  	tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)}
    68  	if dl, ok := ctx.Deadline(); ok {
    69  		tp.Deadline = dl.UnixNano()
    70  	}
    71  
    72  	if v, ok := GetMetadataValue(ctx, "foo"); ok {
    73  		tp.Metadata = v
    74  	}
    75  
    76  	return tp, nil
    77  }
    78  
    79  // registerTestingService mocks more of what is generated code. Unlike grpc, we
    80  // register with a closure so that the descriptor is allocated only on
    81  // registration.
    82  func registerTestingService(srv *Server, svc testingService) {
    83  	srv.Register(serviceName, map[string]Method{
    84  		"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
    85  			var req internal.TestPayload
    86  			if err := unmarshal(&req); err != nil {
    87  				return nil, err
    88  			}
    89  			return svc.Test(ctx, &req)
    90  		},
    91  	})
    92  }
    93  
    94  func protoEqual(a, b proto.Message) (bool, error) {
    95  	ma, err := proto.Marshal(a)
    96  	if err != nil {
    97  		return false, err
    98  	}
    99  	mb, err := proto.Marshal(b)
   100  	if err != nil {
   101  		return false, err
   102  	}
   103  	return bytes.Equal(ma, mb), nil
   104  }
   105  
   106  func TestServer(t *testing.T) {
   107  	var (
   108  		ctx             = context.Background()
   109  		server          = mustServer(t)(NewServer())
   110  		testImpl        = &testingServer{}
   111  		addr, listener  = newTestListener(t)
   112  		client, cleanup = newTestClient(t, addr)
   113  		tclient         = newTestingClient(client)
   114  	)
   115  
   116  	defer listener.Close()
   117  	defer cleanup()
   118  
   119  	registerTestingService(server, testImpl)
   120  
   121  	go server.Serve(ctx, listener)
   122  	defer server.Shutdown(ctx)
   123  
   124  	testCases := []string{"bar", "baz"}
   125  	results := make(chan callResult, len(testCases))
   126  	for _, tc := range testCases {
   127  		go func(expected string) {
   128  			results <- roundTrip(ctx, tclient, expected)
   129  		}(tc)
   130  	}
   131  
   132  	for i := 0; i < len(testCases); {
   133  		result := <-results
   134  		if result.err != nil {
   135  			t.Fatalf("(%s): %v", result.name, result.err)
   136  		}
   137  		equal, err := protoEqual(result.received, result.expected)
   138  		if err != nil {
   139  			t.Fatalf("failed to compare %s and %s: %s", result.received, result.expected, err)
   140  		}
   141  		if !equal {
   142  			t.Fatalf("unexpected response: %+#v != %+#v", result.received, result.expected)
   143  		}
   144  		i++
   145  	}
   146  }
   147  
   148  func TestServerUnimplemented(t *testing.T) {
   149  	var (
   150  		ctx             = context.Background()
   151  		server          = mustServer(t)(NewServer())
   152  		addr, listener  = newTestListener(t)
   153  		errs            = make(chan error, 1)
   154  		client, cleanup = newTestClient(t, addr)
   155  	)
   156  	defer cleanup()
   157  	defer listener.Close()
   158  	go func() {
   159  		errs <- server.Serve(ctx, listener)
   160  	}()
   161  
   162  	var tp internal.TestPayload
   163  	if err := client.Call(ctx, "Not", "Found", &tp, &tp); err == nil {
   164  		t.Fatalf("expected error from non-existent service call")
   165  	} else if status, ok := status.FromError(err); !ok {
   166  		t.Fatalf("expected status present in error: %v", err)
   167  	} else if status.Code() != codes.Unimplemented {
   168  		t.Fatalf("expected not found for method")
   169  	}
   170  
   171  	if err := server.Shutdown(ctx); err != nil {
   172  		t.Fatal(err)
   173  	}
   174  	if err := <-errs; err != ErrServerClosed {
   175  		t.Fatal(err)
   176  	}
   177  }
   178  
   179  func TestServerListenerClosed(t *testing.T) {
   180  	var (
   181  		ctx         = context.Background()
   182  		server      = mustServer(t)(NewServer())
   183  		_, listener = newTestListener(t)
   184  		errs        = make(chan error, 1)
   185  	)
   186  
   187  	go func() {
   188  		errs <- server.Serve(ctx, listener)
   189  	}()
   190  
   191  	if err := listener.Close(); err != nil {
   192  		t.Fatal(err)
   193  	}
   194  
   195  	err := <-errs
   196  	if err == nil {
   197  		t.Fatal(err)
   198  	}
   199  }
   200  
// TestServerShutdown exercises Shutdown while requests are still being
// handled. ncalls requests are parked in a handler that blocks on proceed.
// Once every handler has started, Shutdown is launched in the background and
// the handlers are released. The test then requires that:
//   - each call either succeeded or failed with ErrClosed,
//   - Shutdown returned nil,
//   - Serve returned ErrServerClosed, and
//   - checkServerShutdown finds no remaining listeners or connections
//     (including the idle secondary connection).
func TestServerShutdown(t *testing.T) {
	const ncalls = 5
	var (
		ctx              = context.Background()
		server           = mustServer(t)(NewServer())
		addr, listener   = newTestListener(t)
		shutdownStarted  = make(chan struct{})
		shutdownFinished = make(chan struct{})
		handlersStarted  sync.WaitGroup
		proceed          = make(chan struct{})
		serveErrs        = make(chan error, 1)
		callErrs         = make(chan error, ncalls)
		shutdownErrs     = make(chan error, 1)
		client, cleanup  = newTestClient(t, addr)
		_, cleanup2      = newTestClient(t, addr) // secondary connection
	)
	defer cleanup()
	defer cleanup2()

	// Register a service whose handler marks itself started and then blocks
	// until we tell it to proceed.
	server.Register(serviceName, map[string]Method{
		"Test": func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
			var req internal.TestPayload
			if err := unmarshal(&req); err != nil {
				// NOTE(review): on this path Done is never called, so
				// handlersStarted.Wait below would block forever.
				return nil, err
			}

			// Signal only after the request decoded, so a completed Wait means
			// every call is now blocked inside this handler.
			handlersStarted.Done()
			<-proceed
			return &internal.TestPayload{Foo: "waited"}, nil
		},
	})

	go func() {
		serveErrs <- server.Serve(ctx, listener)
	}()

	// send a series of requests that will get blocked
	for i := 0; i < ncalls; i++ {
		// Add before starting the goroutine so Wait cannot return before
		// every call has been counted.
		handlersStarted.Add(1)
		go func(i int) {
			tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)}
			callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
		}(i)
	}

	// Block until all ncalls handlers are parked on proceed.
	handlersStarted.Wait()
	go func() {
		close(shutdownStarted)
		shutdownErrs <- server.Shutdown(ctx)
		close(shutdownFinished)
	}()

	// NOTE(review): shutdownStarted is closed just before Shutdown is invoked,
	// so proceed may be closed before Shutdown has actually begun. The order
	// of "Shutdown starts" vs. "handlers released" is not strictly enforced.
	<-shutdownStarted
	close(proceed)
	<-shutdownFinished

	// Every call must either have succeeded or been cut off with ErrClosed.
	for i := 0; i < ncalls; i++ {
		if err := <-callErrs; err != nil && err != ErrClosed {
			t.Fatal(err)
		}
	}

	// Shutdown itself must have completed without error.
	if err := <-shutdownErrs; err != nil {
		t.Fatal(err)
	}

	// Serve reports ErrServerClosed once the server has been shut down.
	if err := <-serveErrs; err != ErrServerClosed {
		t.Fatal(err)
	}
	checkServerShutdown(t, server)
}
   273  
   274  func TestServerClose(t *testing.T) {
   275  	var (
   276  		ctx         = context.Background()
   277  		server      = mustServer(t)(NewServer())
   278  		_, listener = newTestListener(t)
   279  		startClose  = make(chan struct{})
   280  		errs        = make(chan error, 1)
   281  	)
   282  
   283  	go func() {
   284  		close(startClose)
   285  		errs <- server.Serve(ctx, listener)
   286  	}()
   287  
   288  	<-startClose
   289  	if err := server.Close(); err != nil {
   290  		t.Fatal(err)
   291  	}
   292  
   293  	err := <-errs
   294  	if err != ErrServerClosed {
   295  		t.Fatal("expected an error from a closed server", err)
   296  	}
   297  
   298  	checkServerShutdown(t, server)
   299  }
   300  
   301  func TestOversizeCall(t *testing.T) {
   302  	var (
   303  		ctx             = context.Background()
   304  		server          = mustServer(t)(NewServer())
   305  		addr, listener  = newTestListener(t)
   306  		errs            = make(chan error, 1)
   307  		client, cleanup = newTestClient(t, addr)
   308  	)
   309  	defer cleanup()
   310  	defer listener.Close()
   311  	go func() {
   312  		errs <- server.Serve(ctx, listener)
   313  	}()
   314  
   315  	registerTestingService(server, &testingServer{})
   316  
   317  	tp := &internal.TestPayload{
   318  		Foo: strings.Repeat("a", 1+messageLengthMax),
   319  	}
   320  	if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
   321  		t.Fatalf("expected error from non-existent service call")
   322  	} else if status, ok := status.FromError(err); !ok {
   323  		t.Fatalf("expected status present in error: %v", err)
   324  	} else if status.Code() != codes.ResourceExhausted {
   325  		t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
   326  	}
   327  
   328  	if err := server.Shutdown(ctx); err != nil {
   329  		t.Fatal(err)
   330  	}
   331  	if err := <-errs; err != ErrServerClosed {
   332  		t.Fatal(err)
   333  	}
   334  }
   335  
   336  func TestClientEOF(t *testing.T) {
   337  	var (
   338  		ctx             = context.Background()
   339  		server          = mustServer(t)(NewServer())
   340  		addr, listener  = newTestListener(t)
   341  		errs            = make(chan error, 1)
   342  		client, cleanup = newTestClient(t, addr)
   343  	)
   344  	defer cleanup()
   345  	defer listener.Close()
   346  	go func() {
   347  		errs <- server.Serve(ctx, listener)
   348  	}()
   349  
   350  	registerTestingService(server, &testingServer{})
   351  
   352  	tp := &internal.TestPayload{}
   353  	// do a regular call
   354  	if err := client.Call(ctx, serviceName, "Test", tp, tp); err != nil {
   355  		t.Fatalf("unexpected error: %v", err)
   356  	}
   357  
   358  	// shutdown the server so the client stops receiving stuff.
   359  	if err := server.Close(); err != nil {
   360  		t.Fatal(err)
   361  	}
   362  	if err := <-errs; err != ErrServerClosed {
   363  		t.Fatal(err)
   364  	}
   365  
   366  	// server shutdown, but we still make a call.
   367  	if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
   368  		t.Fatalf("expected error when calling against shutdown server")
   369  	} else if !errors.Is(err, ErrClosed) {
   370  		var errno syscall.Errno
   371  		if errors.As(err, &errno) {
   372  			t.Logf("errno=%d", errno)
   373  		}
   374  
   375  		t.Fatalf("expected to have a cause of ErrClosed, got %v", err)
   376  	}
   377  }
   378  
   379  func TestServerRequestTimeout(t *testing.T) {
   380  	var (
   381  		ctx, cancel     = context.WithDeadline(context.Background(), time.Now().Add(10*time.Minute))
   382  		server          = mustServer(t)(NewServer())
   383  		addr, listener  = newTestListener(t)
   384  		testImpl        = &testingServer{}
   385  		client, cleanup = newTestClient(t, addr)
   386  		result          internal.TestPayload
   387  	)
   388  	defer cancel()
   389  	defer cleanup()
   390  	defer listener.Close()
   391  
   392  	registerTestingService(server, testImpl)
   393  
   394  	go server.Serve(ctx, listener)
   395  	defer server.Shutdown(ctx)
   396  
   397  	if err := client.Call(ctx, serviceName, "Test", &internal.TestPayload{}, &result); err != nil {
   398  		t.Fatalf("unexpected error making call: %v", err)
   399  	}
   400  
   401  	dl, _ := ctx.Deadline()
   402  	if result.Deadline != dl.UnixNano() {
   403  		t.Fatalf("expected deadline %v, actual: %v", dl, time.Unix(0, result.Deadline))
   404  	}
   405  }
   406  
   407  func TestServerConnectionsLeak(t *testing.T) {
   408  	var (
   409  		ctx             = context.Background()
   410  		server          = mustServer(t)(NewServer())
   411  		addr, listener  = newTestListener(t)
   412  		client, cleanup = newTestClient(t, addr)
   413  	)
   414  	defer cleanup()
   415  	defer listener.Close()
   416  
   417  	connectionCountBefore := server.countConnection()
   418  
   419  	go server.Serve(ctx, listener)
   420  
   421  	registerTestingService(server, &testingServer{})
   422  
   423  	tp := &internal.TestPayload{}
   424  	// do a regular call
   425  	if err := client.Call(ctx, serviceName, "Test", tp, tp); err != nil {
   426  		t.Fatalf("unexpected error during test call: %v", err)
   427  	}
   428  
   429  	connectionCount := server.countConnection()
   430  	if connectionCount != 1 {
   431  		t.Fatalf("unexpected connection count: %d, expected: %d", connectionCount, 1)
   432  	}
   433  
   434  	// close the client, so that server gets EOF
   435  	if err := client.Close(); err != nil {
   436  		t.Fatalf("unexpected error while closing client: %v", err)
   437  	}
   438  
   439  	// server should eventually close the client connection
   440  	maxAttempts := 20
   441  	for i := 1; i <= maxAttempts; i++ {
   442  		connectionCountAfter := server.countConnection()
   443  		if connectionCountAfter == connectionCountBefore {
   444  			break
   445  		}
   446  		if i == maxAttempts {
   447  			t.Fatalf("expected number of connections to be equal %d after client close, got %d connections",
   448  				connectionCountBefore, connectionCountAfter)
   449  		}
   450  		time.Sleep(100 * time.Millisecond)
   451  	}
   452  }
   453  
   454  func BenchmarkRoundTrip(b *testing.B) {
   455  	var (
   456  		ctx             = context.Background()
   457  		server          = mustServer(b)(NewServer())
   458  		testImpl        = &testingServer{}
   459  		addr, listener  = newTestListener(b)
   460  		client, cleanup = newTestClient(b, addr)
   461  		tclient         = newTestingClient(client)
   462  	)
   463  
   464  	defer listener.Close()
   465  	defer cleanup()
   466  
   467  	registerTestingService(server, testImpl)
   468  
   469  	go server.Serve(ctx, listener)
   470  	defer server.Shutdown(ctx)
   471  
   472  	var tp internal.TestPayload
   473  	b.ResetTimer()
   474  
   475  	for i := 0; i < b.N; i++ {
   476  		if _, err := tclient.Test(ctx, &tp); err != nil {
   477  			b.Fatal(err)
   478  		}
   479  	}
   480  }
   481  
   482  func checkServerShutdown(t *testing.T, server *Server) {
   483  	t.Helper()
   484  	server.mu.Lock()
   485  	defer server.mu.Unlock()
   486  
   487  	if len(server.listeners) > 0 {
   488  		t.Errorf("expected listeners to be empty: %v", server.listeners)
   489  	}
   490  	for listener := range server.listeners {
   491  		t.Logf("listener addr=%s", listener.Addr())
   492  	}
   493  
   494  	if len(server.connections) > 0 {
   495  		t.Errorf("expected connections to be empty: %v", server.connections)
   496  	}
   497  	for conn := range server.connections {
   498  		state, ok := conn.getState()
   499  		if !ok {
   500  			t.Errorf("failed to get state from %v", conn)
   501  		}
   502  		t.Logf("conn state=%s", state)
   503  	}
   504  }
   505  
   506  type callResult struct {
   507  	name     string
   508  	err      error
   509  	input    *internal.TestPayload
   510  	expected *internal.TestPayload
   511  	received *internal.TestPayload
   512  }
   513  
   514  func roundTrip(ctx context.Context, client *testingClient, name string) callResult {
   515  	var (
   516  		tp = &internal.TestPayload{
   517  			Foo: name,
   518  		}
   519  	)
   520  
   521  	ctx = WithMetadata(ctx, MD{"foo": []string{name}})
   522  
   523  	resp, err := client.Test(ctx, tp)
   524  	if err != nil {
   525  		return callResult{
   526  			name: name,
   527  			err:  err,
   528  		}
   529  	}
   530  
   531  	return callResult{
   532  		name:     name,
   533  		input:    tp,
   534  		expected: &internal.TestPayload{Foo: strings.Repeat(tp.Foo, 2), Metadata: name},
   535  		received: resp,
   536  	}
   537  }
   538  
   539  func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func()) {
   540  	conn, err := net.Dial("unix", addr)
   541  	if err != nil {
   542  		t.Fatal(err)
   543  	}
   544  	client := NewClient(conn, opts...)
   545  	return client, func() {
   546  		conn.Close()
   547  		client.Close()
   548  	}
   549  }
   550  
   551  func newTestListener(t testing.TB) (string, net.Listener) {
   552  	var prefix string
   553  
   554  	// Abstracts sockets are only available on Linux.
   555  	if runtime.GOOS == "linux" {
   556  		prefix = "\x00"
   557  	}
   558  	addr := prefix + t.Name()
   559  	listener, err := net.Listen("unix", addr)
   560  	if err != nil {
   561  		t.Fatal(err)
   562  	}
   563  
   564  	return addr, listener
   565  }
   566  
   567  func mustServer(t testing.TB) func(server *Server, err error) *Server {
   568  	return func(server *Server, err error) *Server {
   569  		t.Helper()
   570  		if err != nil {
   571  			t.Fatal(err)
   572  		}
   573  
   574  		return server
   575  	}
   576  }
   577  

View as plain text