...

Source file src/google.golang.org/grpc/test/interceptor_test.go

Documentation: google.golang.org/grpc/test

     1  /*
     2   *
     3   * Copyright 2022 gRPC authors.
     4  
     5   *
     6   * Licensed under the Apache License, Version 2.0 (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   *     https://www.apache.org/licenses/LICENSE-2.0
    11   *
    12   * Unless required by applicable law or agreed to in writing, software
    13   * distributed under the License is distributed on an "AS IS" BASIS,
    14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15   * See the License for the specific language governing permissions and
    16   * limitations under the License.
    17   *
    18   */
    19  
    20  package test
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"testing"
    26  
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/internal/stubserver"
    29  	"google.golang.org/grpc/internal/testutils"
    30  
    31  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    32  	testpb "google.golang.org/grpc/interop/grpc_testing"
    33  )
    34  
    35  type parentCtxkey struct{}
    36  type firstInterceptorCtxkey struct{}
    37  type secondInterceptorCtxkey struct{}
    38  type baseInterceptorCtxKey struct{}
    39  
    40  const (
    41  	parentCtxVal            = "parent"
    42  	firstInterceptorCtxVal  = "firstInterceptor"
    43  	secondInterceptorCtxVal = "secondInterceptor"
    44  	baseInterceptorCtxVal   = "baseInterceptor"
    45  )
    46  
    47  // TestUnaryClientInterceptor_ContextValuePropagation verifies that a unary
    48  // interceptor receives context values specified in the context passed to the
    49  // RPC call.
    50  func (s) TestUnaryClientInterceptor_ContextValuePropagation(t *testing.T) {
    51  	errCh := testutils.NewChannel()
    52  	unaryInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    53  		if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
    54  			errCh.Send(fmt.Errorf("unaryInt got %q in context.Val, want %q", got, parentCtxVal))
    55  		}
    56  		errCh.Send(nil)
    57  		return invoker(ctx, method, req, reply, cc, opts...)
    58  	}
    59  
    60  	// Start a stub server and use the above unary interceptor while creating a
    61  	// ClientConn to it.
    62  	ss := &stubserver.StubServer{
    63  		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil },
    64  	}
    65  	if err := ss.Start(nil, grpc.WithUnaryInterceptor(unaryInt)); err != nil {
    66  		t.Fatalf("Failed to start stub server: %v", err)
    67  	}
    68  	defer ss.Stop()
    69  
    70  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    71  	defer cancel()
    72  	if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil {
    73  		t.Fatalf("ss.Client.EmptyCall() failed: %v", err)
    74  	}
    75  	val, err := errCh.Receive(ctx)
    76  	if err != nil {
    77  		t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err)
    78  	}
    79  	if val != nil {
    80  		t.Fatalf("unary interceptor failed: %v", val)
    81  	}
    82  }
    83  
    84  // TestChainUnaryClientInterceptor_ContextValuePropagation verifies that a chain
    85  // of unary interceptors receive context values specified in the original call
    86  // as well as the ones specified by prior interceptors in the chain.
    87  func (s) TestChainUnaryClientInterceptor_ContextValuePropagation(t *testing.T) {
    88  	errCh := testutils.NewChannel()
    89  	firstInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    90  		if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
    91  			errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal))
    92  		}
    93  		if ctx.Value(firstInterceptorCtxkey{}) != nil {
    94  			errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{}))
    95  		}
    96  		if ctx.Value(secondInterceptorCtxkey{}) != nil {
    97  			errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{}))
    98  		}
    99  		firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal)
   100  		return invoker(firstCtx, method, req, reply, cc, opts...)
   101  	}
   102  
   103  	secondInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   104  		if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
   105  			errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal))
   106  		}
   107  		if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal {
   108  			errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal))
   109  		}
   110  		if ctx.Value(secondInterceptorCtxkey{}) != nil {
   111  			errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{}))
   112  		}
   113  		secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal)
   114  		return invoker(secondCtx, method, req, reply, cc, opts...)
   115  	}
   116  
   117  	lastInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   118  		if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
   119  			errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal))
   120  		}
   121  		if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal {
   122  			errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal))
   123  		}
   124  		if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal {
   125  			errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal))
   126  		}
   127  		errCh.SendContext(ctx, nil)
   128  		return invoker(ctx, method, req, reply, cc, opts...)
   129  	}
   130  
   131  	// Start a stub server and use the above chain of interceptors while creating
   132  	// a ClientConn to it.
   133  	ss := &stubserver.StubServer{
   134  		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil },
   135  	}
   136  	if err := ss.Start(nil, grpc.WithChainUnaryInterceptor(firstInt, secondInt, lastInt)); err != nil {
   137  		t.Fatalf("Failed to start stub server: %v", err)
   138  	}
   139  	defer ss.Stop()
   140  
   141  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   142  	defer cancel()
   143  	if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil {
   144  		t.Fatalf("ss.Client.EmptyCall() failed: %v", err)
   145  	}
   146  	val, err := errCh.Receive(ctx)
   147  	if err != nil {
   148  		t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err)
   149  	}
   150  	if val != nil {
   151  		t.Fatalf("unary interceptor failed: %v", val)
   152  	}
   153  }
   154  
   155  // TestChainOnBaseUnaryClientInterceptor_ContextValuePropagation verifies that
   156  // unary interceptors specified as a base interceptor or as a chain interceptor
   157  // receive context values specified in the original call as well as the ones
   158  // specified by interceptors in the chain.
   159  func (s) TestChainOnBaseUnaryClientInterceptor_ContextValuePropagation(t *testing.T) {
   160  	errCh := testutils.NewChannel()
   161  	baseInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   162  		if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
   163  			errCh.SendContext(ctx, fmt.Errorf("base interceptor got %q in context.Val, want %q", got, parentCtxVal))
   164  		}
   165  		if ctx.Value(baseInterceptorCtxKey{}) != nil {
   166  			errCh.SendContext(ctx, fmt.Errorf("baseinterceptor should not have %T in context", baseInterceptorCtxKey{}))
   167  		}
   168  		baseCtx := context.WithValue(ctx, baseInterceptorCtxKey{}, baseInterceptorCtxVal)
   169  		return invoker(baseCtx, method, req, reply, cc, opts...)
   170  	}
   171  
   172  	chainInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   173  		if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
   174  			errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, parentCtxVal))
   175  		}
   176  		if got, ok := ctx.Value(baseInterceptorCtxKey{}).(string); !ok || got != baseInterceptorCtxVal {
   177  			errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, baseInterceptorCtxVal))
   178  		}
   179  		errCh.SendContext(ctx, nil)
   180  		return invoker(ctx, method, req, reply, cc, opts...)
   181  	}
   182  
   183  	// Start a stub server and use the above chain of interceptors while creating
   184  	// a ClientConn to it.
   185  	ss := &stubserver.StubServer{
   186  		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil },
   187  	}
   188  	if err := ss.Start(nil, grpc.WithUnaryInterceptor(baseInt), grpc.WithChainUnaryInterceptor(chainInt)); err != nil {
   189  		t.Fatalf("Failed to start stub server: %v", err)
   190  	}
   191  	defer ss.Stop()
   192  
   193  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   194  	defer cancel()
   195  	if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil {
   196  		t.Fatalf("ss.Client.EmptyCall() failed: %v", err)
   197  	}
   198  	val, err := errCh.Receive(ctx)
   199  	if err != nil {
   200  		t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err)
   201  	}
   202  	if val != nil {
   203  		t.Fatalf("unary interceptor failed: %v", val)
   204  	}
   205  }
   206  
   207  // TestChainStreamClientInterceptor_ContextValuePropagation verifies that a
   208  // chain of stream interceptors receive context values specified in the original
   209  // call as well as the ones specified by the prior interceptors in the chain.
   210  func (s) TestChainStreamClientInterceptor_ContextValuePropagation(t *testing.T) {
   211  	errCh := testutils.NewChannel()
   212  	firstInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   213  		if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
   214  			errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal))
   215  		}
   216  		if ctx.Value(firstInterceptorCtxkey{}) != nil {
   217  			errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{}))
   218  		}
   219  		if ctx.Value(secondInterceptorCtxkey{}) != nil {
   220  			errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{}))
   221  		}
   222  		firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal)
   223  		return streamer(firstCtx, desc, cc, method, opts...)
   224  	}
   225  
   226  	secondInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   227  		if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
   228  			errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal))
   229  		}
   230  		if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal {
   231  			errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal))
   232  		}
   233  		if ctx.Value(secondInterceptorCtxkey{}) != nil {
   234  			errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{}))
   235  		}
   236  		secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal)
   237  		return streamer(secondCtx, desc, cc, method, opts...)
   238  	}
   239  
   240  	lastInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   241  		if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
   242  			errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal))
   243  		}
   244  		if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal {
   245  			errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal))
   246  		}
   247  		if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal {
   248  			errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal))
   249  		}
   250  		errCh.SendContext(ctx, nil)
   251  		return streamer(ctx, desc, cc, method, opts...)
   252  	}
   253  
   254  	// Start a stub server and use the above chain of interceptors while creating
   255  	// a ClientConn to it.
   256  	ss := &stubserver.StubServer{
   257  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
   258  			if _, err := stream.Recv(); err != nil {
   259  				return err
   260  			}
   261  			return stream.Send(&testpb.StreamingOutputCallResponse{})
   262  		},
   263  	}
   264  	if err := ss.Start(nil, grpc.WithChainStreamInterceptor(firstInt, secondInt, lastInt)); err != nil {
   265  		t.Fatalf("Failed to start stub server: %v", err)
   266  	}
   267  	defer ss.Stop()
   268  
   269  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   270  	defer cancel()
   271  	if _, err := ss.Client.FullDuplexCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal)); err != nil {
   272  		t.Fatalf("ss.Client.FullDuplexCall() failed: %v", err)
   273  	}
   274  	val, err := errCh.Receive(ctx)
   275  	if err != nil {
   276  		t.Fatalf("timeout when waiting for stream interceptor to be invoked: %v", err)
   277  	}
   278  	if val != nil {
   279  		t.Fatalf("stream interceptor failed: %v", val)
   280  	}
   281  }
   282  

View as plain text