...

Source file src/github.com/grpc-ecosystem/go-grpc-middleware/chain_test.go

Documentation: github.com/grpc-ecosystem/go-grpc-middleware

     1  // Copyright 2016 Michal Witkowski. All Rights Reserved.
     2  // See LICENSE for licensing terms.
     3  
     4  package grpc_middleware
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/require"
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/metadata"
    14  )
    15  
    16  var (
    17  	someServiceName  = "SomeService.StreamMethod"
    18  	parentUnaryInfo  = &grpc.UnaryServerInfo{FullMethod: someServiceName}
    19  	parentStreamInfo = &grpc.StreamServerInfo{
    20  		FullMethod:     someServiceName,
    21  		IsServerStream: true,
    22  	}
    23  	someValue     = 1
    24  	parentContext = context.WithValue(context.TODO(), "parent", someValue)
    25  )
    26  
    27  func TestChainUnaryServer(t *testing.T) {
    28  	input := "input"
    29  	output := "output"
    30  
    31  	first := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    32  		requireContextValue(t, ctx, "parent", "first interceptor must know the parent context value")
    33  		require.Equal(t, parentUnaryInfo, info, "first interceptor must know the someUnaryServerInfo")
    34  		ctx = context.WithValue(ctx, "first", 1)
    35  		return handler(ctx, req)
    36  	}
    37  	second := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    38  		requireContextValue(t, ctx, "parent", "second interceptor must know the parent context value")
    39  		requireContextValue(t, ctx, "first", "second interceptor must know the first context value")
    40  		require.Equal(t, parentUnaryInfo, info, "second interceptor must know the someUnaryServerInfo")
    41  		ctx = context.WithValue(ctx, "second", 1)
    42  		return handler(ctx, req)
    43  	}
    44  	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
    45  		require.EqualValues(t, input, req, "handler must get the input")
    46  		requireContextValue(t, ctx, "parent", "handler must know the parent context value")
    47  		requireContextValue(t, ctx, "first", "handler must know the first context value")
    48  		requireContextValue(t, ctx, "second", "handler must know the second context value")
    49  		return output, nil
    50  	}
    51  
    52  	chain := ChainUnaryServer(first, second)
    53  	out, _ := chain(parentContext, input, parentUnaryInfo, handler)
    54  	require.EqualValues(t, output, out, "chain must return handler's output")
    55  }
    56  
    57  func TestChainStreamServer(t *testing.T) {
    58  	someService := &struct{}{}
    59  	recvMessage := "received"
    60  	sentMessage := "sent"
    61  	outputError := fmt.Errorf("some error")
    62  
    63  	first := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    64  		requireContextValue(t, stream.Context(), "parent", "first interceptor must know the parent context value")
    65  		require.Equal(t, parentStreamInfo, info, "first interceptor must know the parentStreamInfo")
    66  		require.Equal(t, someService, srv, "first interceptor must know someService")
    67  		wrapped := WrapServerStream(stream)
    68  		wrapped.WrappedContext = context.WithValue(stream.Context(), "first", 1)
    69  		return handler(srv, wrapped)
    70  	}
    71  	second := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    72  		requireContextValue(t, stream.Context(), "parent", "second interceptor must know the parent context value")
    73  		requireContextValue(t, stream.Context(), "first", "second interceptor must know the first context value")
    74  		require.Equal(t, parentStreamInfo, info, "second interceptor must know the parentStreamInfo")
    75  		require.Equal(t, someService, srv, "second interceptor must know someService")
    76  		wrapped := WrapServerStream(stream)
    77  		wrapped.WrappedContext = context.WithValue(stream.Context(), "second", 1)
    78  		return handler(srv, wrapped)
    79  	}
    80  	handler := func(srv interface{}, stream grpc.ServerStream) error {
    81  		require.Equal(t, someService, srv, "handler must know someService")
    82  		requireContextValue(t, stream.Context(), "parent", "handler must know the parent context value")
    83  		requireContextValue(t, stream.Context(), "first", "handler must know the first context value")
    84  		requireContextValue(t, stream.Context(), "second", "handler must know the second context value")
    85  		require.NoError(t, stream.RecvMsg(recvMessage), "handler must have access to stream messages")
    86  		require.NoError(t, stream.SendMsg(sentMessage), "handler must be able to send stream messages")
    87  		return outputError
    88  	}
    89  	fakeStream := &fakeServerStream{ctx: parentContext, recvMessage: recvMessage}
    90  	chain := ChainStreamServer(first, second)
    91  	err := chain(someService, fakeStream, parentStreamInfo, handler)
    92  	require.Equal(t, outputError, err, "chain must return handler's error")
    93  	require.Equal(t, sentMessage, fakeStream.sentMessage, "handler's sent message must propagate to stream")
    94  }
    95  
    96  func TestChainUnaryClient(t *testing.T) {
    97  	ignoredMd := metadata.Pairs("foo", "bar")
    98  	parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)}
    99  	reqMessage := "request"
   100  	replyMessage := "reply"
   101  	outputError := fmt.Errorf("some error")
   102  
   103  	first := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   104  		requireContextValue(t, ctx, "parent", "first must know the parent context value")
   105  		require.Equal(t, someServiceName, method, "first must know someService")
   106  		require.Len(t, opts, 1, "first should see parent CallOptions")
   107  		wrappedCtx := context.WithValue(ctx, "first", 1)
   108  		return invoker(wrappedCtx, method, req, reply, cc, opts...)
   109  	}
   110  	second := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   111  		requireContextValue(t, ctx, "parent", "second must know the parent context value")
   112  		requireContextValue(t, ctx, "first", "second must know the first context value")
   113  		require.Equal(t, someServiceName, method, "second must know someService")
   114  		require.Len(t, opts, 1, "second should see parent CallOptions")
   115  		wrappedOpts := append(opts, grpc.WaitForReady(false))
   116  		wrappedCtx := context.WithValue(ctx, "second", 1)
   117  		return invoker(wrappedCtx, method, req, reply, cc, wrappedOpts...)
   118  	}
   119  	invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
   120  		require.Equal(t, someServiceName, method, "invoker must know someService")
   121  		requireContextValue(t, ctx, "parent", "invoker must know the parent context value")
   122  		requireContextValue(t, ctx, "first", "invoker must know the first context value")
   123  		requireContextValue(t, ctx, "second", "invoker must know the second context value")
   124  		require.Len(t, opts, 2, "invoker should see both CallOpts from second and parent")
   125  		return outputError
   126  	}
   127  	chain := ChainUnaryClient(first, second)
   128  	err := chain(parentContext, someServiceName, reqMessage, replyMessage, nil, invoker, parentOpts...)
   129  	require.Equal(t, outputError, err, "chain must return invokers's error")
   130  }
   131  
   132  func TestChainStreamClient(t *testing.T) {
   133  	ignoredMd := metadata.Pairs("foo", "bar")
   134  	parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)}
   135  	clientStream := &fakeClientStream{}
   136  	fakeStreamDesc := &grpc.StreamDesc{ClientStreams: true, ServerStreams: true, StreamName: someServiceName}
   137  
   138  	first := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   139  		requireContextValue(t, ctx, "parent", "first must know the parent context value")
   140  		require.Equal(t, someServiceName, method, "first must know someService")
   141  		require.Len(t, opts, 1, "first should see parent CallOptions")
   142  		wrappedCtx := context.WithValue(ctx, "first", 1)
   143  		return streamer(wrappedCtx, desc, cc, method, opts...)
   144  	}
   145  	second := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   146  		requireContextValue(t, ctx, "parent", "second must know the parent context value")
   147  		requireContextValue(t, ctx, "first", "second must know the first context value")
   148  		require.Equal(t, someServiceName, method, "second must know someService")
   149  		require.Len(t, opts, 1, "second should see parent CallOptions")
   150  		wrappedOpts := append(opts, grpc.WaitForReady(false))
   151  		wrappedCtx := context.WithValue(ctx, "second", 1)
   152  		return streamer(wrappedCtx, desc, cc, method, wrappedOpts...)
   153  	}
   154  	streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   155  		require.Equal(t, someServiceName, method, "streamer must know someService")
   156  		require.Equal(t, fakeStreamDesc, desc, "streamer must see the right StreamDesc")
   157  
   158  		requireContextValue(t, ctx, "parent", "streamer must know the parent context value")
   159  		requireContextValue(t, ctx, "first", "streamer must know the first context value")
   160  		requireContextValue(t, ctx, "second", "streamer must know the second context value")
   161  		require.Len(t, opts, 2, "streamer should see both CallOpts from second and parent")
   162  		return clientStream, nil
   163  	}
   164  	chain := ChainStreamClient(first, second)
   165  	someStream, err := chain(parentContext, fakeStreamDesc, nil, someServiceName, streamer, parentOpts...)
   166  	require.NoError(t, err, "chain must not return an error")
   167  	require.Equal(t, clientStream, someStream, "chain must return invokers's clientstream")
   168  }
   169  
   170  func requireContextValue(t *testing.T, ctx context.Context, key string, msg ...interface{}) {
   171  	val := ctx.Value(key)
   172  	require.NotNil(t, val, msg...)
   173  	require.Equal(t, someValue, val, msg...)
   174  }
   175  

View as plain text