...

Source file src/google.golang.org/grpc/experimental/shared_buffer_pool_test.go

Documentation: google.golang.org/grpc/experimental

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package experimental_test
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"io"
    25  	"testing"
    26  	"time"
    27  
    28  	"google.golang.org/grpc"
    29  	"google.golang.org/grpc/encoding/gzip"
    30  	"google.golang.org/grpc/experimental"
    31  	"google.golang.org/grpc/internal/grpctest"
    32  	"google.golang.org/grpc/internal/stubserver"
    33  
    34  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    35  	testpb "google.golang.org/grpc/interop/grpc_testing"
    36  )
    37  
    38  type s struct {
    39  	grpctest.Tester
    40  }
    41  
    42  func Test(t *testing.T) {
    43  	grpctest.RunSubTests(t, s{})
    44  }
    45  
    46  const defaultTestTimeout = 10 * time.Second
    47  
    48  func (s) TestRecvBufferPoolStream(t *testing.T) {
    49  	tcs := []struct {
    50  		name     string
    51  		callOpts []grpc.CallOption
    52  	}{
    53  		{
    54  			name: "default",
    55  		},
    56  		{
    57  			name: "useCompressor",
    58  			callOpts: []grpc.CallOption{
    59  				grpc.UseCompressor(gzip.Name),
    60  			},
    61  		},
    62  	}
    63  
    64  	for _, tc := range tcs {
    65  		t.Run(tc.name, func(t *testing.T) {
    66  			const reqCount = 10
    67  
    68  			ss := &stubserver.StubServer{
    69  				FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
    70  					for i := 0; i < reqCount; i++ {
    71  						preparedMsg := &grpc.PreparedMsg{}
    72  						if err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{
    73  							Payload: &testpb.Payload{
    74  								Body: []byte{'0' + uint8(i)},
    75  							},
    76  						}); err != nil {
    77  							return err
    78  						}
    79  						stream.SendMsg(preparedMsg)
    80  					}
    81  					return nil
    82  				},
    83  			}
    84  
    85  			pool := &checkBufferPool{}
    86  			sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
    87  			dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
    88  			if err := ss.Start(sopts, dopts...); err != nil {
    89  				t.Fatalf("Error starting endpoint server: %v", err)
    90  			}
    91  			defer ss.Stop()
    92  
    93  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    94  			defer cancel()
    95  
    96  			stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
    97  			if err != nil {
    98  				t.Fatalf("ss.Client.FullDuplexCall failed: %v", err)
    99  			}
   100  
   101  			var ngot int
   102  			var buf bytes.Buffer
   103  			for {
   104  				reply, err := stream.Recv()
   105  				if err == io.EOF {
   106  					break
   107  				}
   108  				if err != nil {
   109  					t.Fatal(err)
   110  				}
   111  				ngot++
   112  				if buf.Len() > 0 {
   113  					buf.WriteByte(',')
   114  				}
   115  				buf.Write(reply.GetPayload().GetBody())
   116  			}
   117  			if want := 10; ngot != want {
   118  				t.Fatalf("Got %d replies, want %d", ngot, want)
   119  			}
   120  			if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
   121  				t.Fatalf("Got replies %q; want %q", got, want)
   122  			}
   123  
   124  			if len(pool.puts) != reqCount {
   125  				t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
   126  			}
   127  		})
   128  	}
   129  }
   130  
   131  func (s) TestRecvBufferPoolUnary(t *testing.T) {
   132  	tcs := []struct {
   133  		name     string
   134  		callOpts []grpc.CallOption
   135  	}{
   136  		{
   137  			name: "default",
   138  		},
   139  		{
   140  			name: "useCompressor",
   141  			callOpts: []grpc.CallOption{
   142  				grpc.UseCompressor(gzip.Name),
   143  			},
   144  		},
   145  	}
   146  
   147  	for _, tc := range tcs {
   148  		t.Run(tc.name, func(t *testing.T) {
   149  			const largeSize = 1024
   150  
   151  			ss := &stubserver.StubServer{
   152  				UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   153  					return &testpb.SimpleResponse{
   154  						Payload: &testpb.Payload{
   155  							Body: make([]byte, largeSize),
   156  						},
   157  					}, nil
   158  				},
   159  			}
   160  
   161  			pool := &checkBufferPool{}
   162  			sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
   163  			dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
   164  			if err := ss.Start(sopts, dopts...); err != nil {
   165  				t.Fatalf("Error starting endpoint server: %v", err)
   166  			}
   167  			defer ss.Stop()
   168  
   169  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   170  			defer cancel()
   171  
   172  			const reqCount = 10
   173  			for i := 0; i < reqCount; i++ {
   174  				if _, err := ss.Client.UnaryCall(
   175  					ctx,
   176  					&testpb.SimpleRequest{
   177  						Payload: &testpb.Payload{
   178  							Body: make([]byte, largeSize),
   179  						},
   180  					},
   181  					tc.callOpts...,
   182  				); err != nil {
   183  					t.Fatalf("ss.Client.UnaryCall failed: %v", err)
   184  				}
   185  			}
   186  
   187  			const bufferCount = reqCount * 2 // req + resp
   188  			if len(pool.puts) != bufferCount {
   189  				t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts))
   190  			}
   191  		})
   192  	}
   193  }
   194  
   195  type checkBufferPool struct {
   196  	puts [][]byte
   197  }
   198  
   199  func (p *checkBufferPool) Get(size int) []byte {
   200  	return make([]byte, size)
   201  }
   202  
   203  func (p *checkBufferPool) Put(bs *[]byte) {
   204  	p.puts = append(p.puts, *bs)
   205  }
   206  

View as plain text