...

Source file src/google.golang.org/grpc/test/compressor_test.go

Documentation: google.golang.org/grpc/test

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package test
    20  
    21  import (
    22  	"bytes"
    23  	"compress/gzip"
    24  	"context"
    25  	"io"
    26  	"reflect"
    27  	"strings"
    28  	"sync/atomic"
    29  	"testing"
    30  
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/encoding"
    34  	"google.golang.org/grpc/internal/stubserver"
    35  	"google.golang.org/grpc/metadata"
    36  	"google.golang.org/grpc/status"
    37  
    38  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    39  	testpb "google.golang.org/grpc/interop/grpc_testing"
    40  )
    41  
    42  func (s) TestCompressServerHasNoSupport(t *testing.T) {
    43  	for _, e := range listTestEnv() {
    44  		testCompressServerHasNoSupport(t, e)
    45  	}
    46  }
    47  
    48  func testCompressServerHasNoSupport(t *testing.T, e env) {
    49  	te := newTest(t, e)
    50  	te.serverCompression = false
    51  	te.clientCompression = false
    52  	te.clientNopCompression = true
    53  	te.startServer(&testServer{security: e.security})
    54  	defer te.tearDown()
    55  	tc := testgrpc.NewTestServiceClient(te.clientConn())
    56  
    57  	const argSize = 271828
    58  	const respSize = 314159
    59  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  	req := &testpb.SimpleRequest{
    64  		ResponseType: testpb.PayloadType_COMPRESSABLE,
    65  		ResponseSize: respSize,
    66  		Payload:      payload,
    67  	}
    68  
    69  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    70  	defer cancel()
    71  	if _, err := tc.UnaryCall(ctx, req); err == nil || status.Code(err) != codes.Unimplemented {
    72  		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code %s", err, codes.Unimplemented)
    73  	}
    74  	// Streaming RPC
    75  	stream, err := tc.FullDuplexCall(ctx)
    76  	if err != nil {
    77  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
    78  	}
    79  	if _, err := stream.Recv(); err == nil || status.Code(err) != codes.Unimplemented {
    80  		t.Fatalf("%v.Recv() = %v, want error code %s", stream, err, codes.Unimplemented)
    81  	}
    82  }
    83  
    84  func (s) TestCompressOK(t *testing.T) {
    85  	for _, e := range listTestEnv() {
    86  		testCompressOK(t, e)
    87  	}
    88  }
    89  
    90  func testCompressOK(t *testing.T, e env) {
    91  	te := newTest(t, e)
    92  	te.serverCompression = true
    93  	te.clientCompression = true
    94  	te.startServer(&testServer{security: e.security})
    95  	defer te.tearDown()
    96  	tc := testgrpc.NewTestServiceClient(te.clientConn())
    97  
    98  	// Unary call
    99  	const argSize = 271828
   100  	const respSize = 314159
   101  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  	req := &testpb.SimpleRequest{
   106  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   107  		ResponseSize: respSize,
   108  		Payload:      payload,
   109  	}
   110  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   111  	defer cancel()
   112  	ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("something", "something"))
   113  	if _, err := tc.UnaryCall(ctx, req); err != nil {
   114  		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
   115  	}
   116  	// Streaming RPC
   117  	stream, err := tc.FullDuplexCall(ctx)
   118  	if err != nil {
   119  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   120  	}
   121  	respParam := []*testpb.ResponseParameters{
   122  		{
   123  			Size: 31415,
   124  		},
   125  	}
   126  	payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
   127  	if err != nil {
   128  		t.Fatal(err)
   129  	}
   130  	sreq := &testpb.StreamingOutputCallRequest{
   131  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   132  		ResponseParameters: respParam,
   133  		Payload:            payload,
   134  	}
   135  	if err := stream.Send(sreq); err != nil {
   136  		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
   137  	}
   138  	stream.CloseSend()
   139  	if _, err := stream.Recv(); err != nil {
   140  		t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
   141  	}
   142  	if _, err := stream.Recv(); err != io.EOF {
   143  		t.Fatalf("%v.Recv() = %v, want io.EOF", stream, err)
   144  	}
   145  }
   146  
   147  func (s) TestIdentityEncoding(t *testing.T) {
   148  	for _, e := range listTestEnv() {
   149  		testIdentityEncoding(t, e)
   150  	}
   151  }
   152  
   153  func testIdentityEncoding(t *testing.T, e env) {
   154  	te := newTest(t, e)
   155  	te.startServer(&testServer{security: e.security})
   156  	defer te.tearDown()
   157  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   158  
   159  	// Unary call
   160  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 5)
   161  	if err != nil {
   162  		t.Fatal(err)
   163  	}
   164  	req := &testpb.SimpleRequest{
   165  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   166  		ResponseSize: 10,
   167  		Payload:      payload,
   168  	}
   169  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   170  	defer cancel()
   171  	ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("something", "something"))
   172  	if _, err := tc.UnaryCall(ctx, req); err != nil {
   173  		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
   174  	}
   175  	// Streaming RPC
   176  	stream, err := tc.FullDuplexCall(ctx, grpc.UseCompressor("identity"))
   177  	if err != nil {
   178  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   179  	}
   180  	payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
   181  	if err != nil {
   182  		t.Fatal(err)
   183  	}
   184  	sreq := &testpb.StreamingOutputCallRequest{
   185  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   186  		ResponseParameters: []*testpb.ResponseParameters{{Size: 10}},
   187  		Payload:            payload,
   188  	}
   189  	if err := stream.Send(sreq); err != nil {
   190  		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
   191  	}
   192  	stream.CloseSend()
   193  	if _, err := stream.Recv(); err != nil {
   194  		t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
   195  	}
   196  	if _, err := stream.Recv(); err != io.EOF {
   197  		t.Fatalf("%v.Recv() = %v, want io.EOF", stream, err)
   198  	}
   199  }
   200  
   201  // renameCompressor is a grpc.Compressor wrapper that allows customizing the
   202  // Type() of another compressor.
   203  type renameCompressor struct {
   204  	grpc.Compressor
   205  	name string
   206  }
   207  
   208  func (r *renameCompressor) Type() string { return r.name }
   209  
   210  // renameDecompressor is a grpc.Decompressor wrapper that allows customizing the
   211  // Type() of another Decompressor.
   212  type renameDecompressor struct {
   213  	grpc.Decompressor
   214  	name string
   215  }
   216  
   217  func (r *renameDecompressor) Type() string { return r.name }
   218  
   219  func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) {
   220  	wantGrpcAcceptEncodingCh := make(chan []string, 1)
   221  	defer close(wantGrpcAcceptEncodingCh)
   222  
   223  	compressor := renameCompressor{Compressor: grpc.NewGZIPCompressor(), name: "testgzip"}
   224  	decompressor := renameDecompressor{Decompressor: grpc.NewGZIPDecompressor(), name: "testgzip"}
   225  
   226  	ss := &stubserver.StubServer{
   227  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   228  			md, ok := metadata.FromIncomingContext(ctx)
   229  			if !ok {
   230  				return nil, status.Errorf(codes.Internal, "no metadata in context")
   231  			}
   232  			if got, want := md["grpc-accept-encoding"], <-wantGrpcAcceptEncodingCh; !reflect.DeepEqual(got, want) {
   233  				return nil, status.Errorf(codes.Internal, "got grpc-accept-encoding=%q; want [%q]", got, want)
   234  			}
   235  			return &testpb.Empty{}, nil
   236  		},
   237  	}
   238  	if err := ss.Start([]grpc.ServerOption{grpc.RPCDecompressor(&decompressor)}); err != nil {
   239  		t.Fatalf("Error starting endpoint server: %v", err)
   240  	}
   241  	defer ss.Stop()
   242  
   243  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   244  	defer cancel()
   245  
   246  	wantGrpcAcceptEncodingCh <- []string{"gzip"}
   247  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   248  		t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
   249  	}
   250  
   251  	wantGrpcAcceptEncodingCh <- []string{"gzip"}
   252  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.UseCompressor("gzip")); err != nil {
   253  		t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
   254  	}
   255  
   256  	// Use compressor directly which is not registered via
   257  	// encoding.RegisterCompressor.
   258  	if err := ss.StartClient(grpc.WithCompressor(&compressor)); err != nil {
   259  		t.Fatalf("Error starting client: %v", err)
   260  	}
   261  	wantGrpcAcceptEncodingCh <- []string{"gzip,testgzip"}
   262  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   263  		t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
   264  	}
   265  }
   266  
   267  // wrapCompressor is a wrapper of encoding.Compressor which maintains count of
   268  // Compressor method invokes.
   269  type wrapCompressor struct {
   270  	encoding.Compressor
   271  	compressInvokes int32
   272  }
   273  
   274  func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
   275  	atomic.AddInt32(&wc.compressInvokes, 1)
   276  	return wc.Compressor.Compress(w)
   277  }
   278  
   279  func setupGzipWrapCompressor(t *testing.T) *wrapCompressor {
   280  	oldC := encoding.GetCompressor("gzip")
   281  	c := &wrapCompressor{Compressor: oldC}
   282  	encoding.RegisterCompressor(c)
   283  	t.Cleanup(func() {
   284  		encoding.RegisterCompressor(oldC)
   285  	})
   286  	return c
   287  }
   288  
   289  func (s) TestSetSendCompressorSuccess(t *testing.T) {
   290  	for _, tt := range []struct {
   291  		name                string
   292  		desc                string
   293  		payload             *testpb.Payload
   294  		dialOpts            []grpc.DialOption
   295  		resCompressor       string
   296  		wantCompressInvokes int32
   297  	}{
   298  		{
   299  			name:                "identity_request_and_gzip_response",
   300  			desc:                "request is uncompressed and response is gzip compressed",
   301  			payload:             &testpb.Payload{Body: []byte("payload")},
   302  			resCompressor:       "gzip",
   303  			wantCompressInvokes: 1,
   304  		},
   305  		{
   306  			name:                "identity_request_and_empty_response",
   307  			desc:                "request is uncompressed and response is gzip compressed",
   308  			payload:             nil,
   309  			resCompressor:       "gzip",
   310  			wantCompressInvokes: 0,
   311  		},
   312  		{
   313  			name:          "gzip_request_and_identity_response",
   314  			desc:          "request is gzip compressed and response is uncompressed with identity",
   315  			payload:       &testpb.Payload{Body: []byte("payload")},
   316  			resCompressor: "identity",
   317  			dialOpts: []grpc.DialOption{
   318  				// Use WithCompressor instead of UseCompressor to avoid counting
   319  				// the client's compressor usage.
   320  				grpc.WithCompressor(grpc.NewGZIPCompressor()),
   321  			},
   322  			wantCompressInvokes: 0,
   323  		},
   324  	} {
   325  		t.Run(tt.name, func(t *testing.T) {
   326  			t.Run("unary", func(t *testing.T) {
   327  				testUnarySetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
   328  			})
   329  
   330  			t.Run("stream", func(t *testing.T) {
   331  				testStreamSetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
   332  			})
   333  		})
   334  	}
   335  }
   336  
   337  func testUnarySetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
   338  	wc := setupGzipWrapCompressor(t)
   339  	ss := &stubserver.StubServer{
   340  		UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   341  			if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
   342  				return nil, err
   343  			}
   344  			return &testpb.SimpleResponse{
   345  				Payload: payload,
   346  			}, nil
   347  		},
   348  	}
   349  	if err := ss.Start(nil, dialOpts...); err != nil {
   350  		t.Fatalf("Error starting endpoint server: %v", err)
   351  	}
   352  	defer ss.Stop()
   353  
   354  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   355  	defer cancel()
   356  
   357  	if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
   358  		t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
   359  	}
   360  
   361  	compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
   362  	if compressInvokes != wantCompressInvokes {
   363  		t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
   364  	}
   365  }
   366  
   367  func testStreamSetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
   368  	wc := setupGzipWrapCompressor(t)
   369  	ss := &stubserver.StubServer{
   370  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
   371  			if _, err := stream.Recv(); err != nil {
   372  				return err
   373  			}
   374  
   375  			if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil {
   376  				return err
   377  			}
   378  
   379  			return stream.Send(&testpb.StreamingOutputCallResponse{
   380  				Payload: payload,
   381  			})
   382  		},
   383  	}
   384  	if err := ss.Start(nil, dialOpts...); err != nil {
   385  		t.Fatalf("Error starting endpoint server: %v", err)
   386  	}
   387  	defer ss.Stop()
   388  
   389  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   390  	defer cancel()
   391  
   392  	s, err := ss.Client.FullDuplexCall(ctx)
   393  	if err != nil {
   394  		t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
   395  	}
   396  
   397  	if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
   398  		t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err)
   399  	}
   400  
   401  	if _, err := s.Recv(); err != nil {
   402  		t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
   403  	}
   404  
   405  	compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
   406  	if compressInvokes != wantCompressInvokes {
   407  		t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
   408  	}
   409  }
   410  
   411  func (s) TestUnregisteredSetSendCompressorFailure(t *testing.T) {
   412  	resCompressor := "snappy2"
   413  	wantErr := status.Error(codes.Unknown, "unable to set send compressor: compressor not registered \"snappy2\"")
   414  
   415  	t.Run("unary", func(t *testing.T) {
   416  		testUnarySetSendCompressorFailure(t, resCompressor, wantErr)
   417  	})
   418  
   419  	t.Run("stream", func(t *testing.T) {
   420  		testStreamSetSendCompressorFailure(t, resCompressor, wantErr)
   421  	})
   422  }
   423  
   424  func testUnarySetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) {
   425  	ss := &stubserver.StubServer{
   426  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   427  			if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
   428  				return nil, err
   429  			}
   430  			return &testpb.Empty{}, nil
   431  		},
   432  	}
   433  	if err := ss.Start(nil); err != nil {
   434  		t.Fatalf("Error starting endpoint server: %v", err)
   435  	}
   436  	defer ss.Stop()
   437  
   438  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   439  	defer cancel()
   440  
   441  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) {
   442  		t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr)
   443  	}
   444  }
   445  
   446  func testStreamSetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) {
   447  	ss := &stubserver.StubServer{
   448  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
   449  			if _, err := stream.Recv(); err != nil {
   450  				return err
   451  			}
   452  
   453  			if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil {
   454  				return err
   455  			}
   456  
   457  			return stream.Send(&testpb.StreamingOutputCallResponse{})
   458  		},
   459  	}
   460  	if err := ss.Start(nil); err != nil {
   461  		t.Fatalf("Error starting endpoint server: %v, want: nil", err)
   462  	}
   463  	defer ss.Stop()
   464  
   465  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   466  	defer cancel()
   467  
   468  	s, err := ss.Client.FullDuplexCall(ctx)
   469  	if err != nil {
   470  		t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
   471  	}
   472  
   473  	if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
   474  		t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err)
   475  	}
   476  
   477  	if _, err := s.Recv(); !equalError(err, wantErr) {
   478  		t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
   479  	}
   480  }
   481  
   482  func (s) TestUnarySetSendCompressorAfterHeaderSendFailure(t *testing.T) {
   483  	ss := &stubserver.StubServer{
   484  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   485  			// Send headers early and then set send compressor.
   486  			grpc.SendHeader(ctx, metadata.MD{})
   487  			err := grpc.SetSendCompressor(ctx, "gzip")
   488  			if err == nil {
   489  				t.Error("Wanted set send compressor error")
   490  				return &testpb.Empty{}, nil
   491  			}
   492  			return nil, err
   493  		},
   494  	}
   495  	if err := ss.Start(nil); err != nil {
   496  		t.Fatalf("Error starting endpoint server: %v", err)
   497  	}
   498  	defer ss.Stop()
   499  
   500  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   501  	defer cancel()
   502  
   503  	wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done")
   504  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) {
   505  		t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr)
   506  	}
   507  }
   508  
   509  func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) {
   510  	ss := &stubserver.StubServer{
   511  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
   512  			// Send headers early and then set send compressor.
   513  			grpc.SendHeader(stream.Context(), metadata.MD{})
   514  			err := grpc.SetSendCompressor(stream.Context(), "gzip")
   515  			if err == nil {
   516  				t.Error("Wanted set send compressor error")
   517  			}
   518  			return err
   519  		},
   520  	}
   521  	if err := ss.Start(nil); err != nil {
   522  		t.Fatalf("Error starting endpoint server: %v", err)
   523  	}
   524  	defer ss.Stop()
   525  
   526  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   527  	defer cancel()
   528  
   529  	wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done")
   530  	s, err := ss.Client.FullDuplexCall(ctx)
   531  	if err != nil {
   532  		t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
   533  	}
   534  
   535  	if _, err := s.Recv(); !equalError(err, wantErr) {
   536  		t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, wantErr)
   537  	}
   538  }
   539  
   540  func (s) TestClientSupportedCompressors(t *testing.T) {
   541  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   542  	defer cancel()
   543  
   544  	for _, tt := range []struct {
   545  		desc string
   546  		ctx  context.Context
   547  		want []string
   548  	}{
   549  		{
   550  			desc: "No additional grpc-accept-encoding header",
   551  			ctx:  ctx,
   552  			want: []string{"gzip"},
   553  		},
   554  		{
   555  			desc: "With additional grpc-accept-encoding header",
   556  			ctx: metadata.AppendToOutgoingContext(ctx,
   557  				"grpc-accept-encoding", "test-compressor-1",
   558  				"grpc-accept-encoding", "test-compressor-2",
   559  			),
   560  			want: []string{"gzip", "test-compressor-1", "test-compressor-2"},
   561  		},
   562  		{
   563  			desc: "With additional empty grpc-accept-encoding header",
   564  			ctx: metadata.AppendToOutgoingContext(ctx,
   565  				"grpc-accept-encoding", "",
   566  			),
   567  			want: []string{"gzip"},
   568  		},
   569  		{
   570  			desc: "With additional grpc-accept-encoding header with spaces between values",
   571  			ctx: metadata.AppendToOutgoingContext(ctx,
   572  				"grpc-accept-encoding", "identity, deflate",
   573  			),
   574  			want: []string{"gzip", "identity", "deflate"},
   575  		},
   576  	} {
   577  		t.Run(tt.desc, func(t *testing.T) {
   578  			ss := &stubserver.StubServer{
   579  				EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   580  					got, err := grpc.ClientSupportedCompressors(ctx)
   581  					if err != nil {
   582  						return nil, err
   583  					}
   584  
   585  					if !reflect.DeepEqual(got, tt.want) {
   586  						t.Errorf("unexpected client compressors got: %v, want: %v", got, tt.want)
   587  					}
   588  
   589  					return &testpb.Empty{}, nil
   590  				},
   591  			}
   592  			if err := ss.Start(nil); err != nil {
   593  				t.Fatalf("Error starting endpoint server: %v, want: nil", err)
   594  			}
   595  			defer ss.Stop()
   596  
   597  			_, err := ss.Client.EmptyCall(tt.ctx, &testpb.Empty{})
   598  			if err != nil {
   599  				t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
   600  			}
   601  		})
   602  	}
   603  }
   604  
   605  func (s) TestCompressorRegister(t *testing.T) {
   606  	for _, e := range listTestEnv() {
   607  		testCompressorRegister(t, e)
   608  	}
   609  }
   610  
   611  func testCompressorRegister(t *testing.T, e env) {
   612  	te := newTest(t, e)
   613  	te.clientCompression = false
   614  	te.serverCompression = false
   615  	te.clientUseCompression = true
   616  
   617  	te.startServer(&testServer{security: e.security})
   618  	defer te.tearDown()
   619  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   620  
   621  	// Unary call
   622  	const argSize = 271828
   623  	const respSize = 314159
   624  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
   625  	if err != nil {
   626  		t.Fatal(err)
   627  	}
   628  	req := &testpb.SimpleRequest{
   629  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   630  		ResponseSize: respSize,
   631  		Payload:      payload,
   632  	}
   633  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   634  	defer cancel()
   635  	ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("something", "something"))
   636  	if _, err := tc.UnaryCall(ctx, req); err != nil {
   637  		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
   638  	}
   639  	// Streaming RPC
   640  	stream, err := tc.FullDuplexCall(ctx)
   641  	if err != nil {
   642  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   643  	}
   644  	respParam := []*testpb.ResponseParameters{
   645  		{
   646  			Size: 31415,
   647  		},
   648  	}
   649  	payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
   650  	if err != nil {
   651  		t.Fatal(err)
   652  	}
   653  	sreq := &testpb.StreamingOutputCallRequest{
   654  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   655  		ResponseParameters: respParam,
   656  		Payload:            payload,
   657  	}
   658  	if err := stream.Send(sreq); err != nil {
   659  		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
   660  	}
   661  	if _, err := stream.Recv(); err != nil {
   662  		t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
   663  	}
   664  }
   665  
   666  type badGzipCompressor struct{}
   667  
   668  func (badGzipCompressor) Do(w io.Writer, p []byte) error {
   669  	buf := &bytes.Buffer{}
   670  	gzw := gzip.NewWriter(buf)
   671  	if _, err := gzw.Write(p); err != nil {
   672  		return err
   673  	}
   674  	err := gzw.Close()
   675  	bs := buf.Bytes()
   676  	if len(bs) >= 6 {
   677  		bs[len(bs)-6] ^= 1 // modify checksum at end by 1 byte
   678  	}
   679  	w.Write(bs)
   680  	return err
   681  }
   682  
   683  func (badGzipCompressor) Type() string {
   684  	return "gzip"
   685  }
   686  
   687  func (s) TestGzipBadChecksum(t *testing.T) {
   688  	ss := &stubserver.StubServer{
   689  		UnaryCallF: func(ctx context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   690  			return &testpb.SimpleResponse{}, nil
   691  		},
   692  	}
   693  	if err := ss.Start(nil, grpc.WithCompressor(badGzipCompressor{})); err != nil {
   694  		t.Fatalf("Error starting endpoint server: %v", err)
   695  	}
   696  	defer ss.Stop()
   697  
   698  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   699  	defer cancel()
   700  
   701  	p, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1024))
   702  	if err != nil {
   703  		t.Fatalf("Unexpected error from newPayload: %v", err)
   704  	}
   705  	if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: p}); err == nil ||
   706  		status.Code(err) != codes.Internal ||
   707  		!strings.Contains(status.Convert(err).Message(), gzip.ErrChecksum.Error()) {
   708  		t.Errorf("ss.Client.UnaryCall(_) = _, %v\n\twant: _, status(codes.Internal, contains %q)", err, gzip.ErrChecksum)
   709  	}
   710  }
   711  

View as plain text