1
2
3
4 package grpc_middleware
5
6 import (
7 "context"
8 "fmt"
9 "testing"
10
11 "github.com/stretchr/testify/require"
12 "google.golang.org/grpc"
13 "google.golang.org/grpc/metadata"
14 )
15
16 var (
17 someServiceName = "SomeService.StreamMethod"
18 parentUnaryInfo = &grpc.UnaryServerInfo{FullMethod: someServiceName}
19 parentStreamInfo = &grpc.StreamServerInfo{
20 FullMethod: someServiceName,
21 IsServerStream: true,
22 }
23 someValue = 1
24 parentContext = context.WithValue(context.TODO(), "parent", someValue)
25 )
26
27 func TestChainUnaryServer(t *testing.T) {
28 input := "input"
29 output := "output"
30
31 first := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
32 requireContextValue(t, ctx, "parent", "first interceptor must know the parent context value")
33 require.Equal(t, parentUnaryInfo, info, "first interceptor must know the someUnaryServerInfo")
34 ctx = context.WithValue(ctx, "first", 1)
35 return handler(ctx, req)
36 }
37 second := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
38 requireContextValue(t, ctx, "parent", "second interceptor must know the parent context value")
39 requireContextValue(t, ctx, "first", "second interceptor must know the first context value")
40 require.Equal(t, parentUnaryInfo, info, "second interceptor must know the someUnaryServerInfo")
41 ctx = context.WithValue(ctx, "second", 1)
42 return handler(ctx, req)
43 }
44 handler := func(ctx context.Context, req interface{}) (interface{}, error) {
45 require.EqualValues(t, input, req, "handler must get the input")
46 requireContextValue(t, ctx, "parent", "handler must know the parent context value")
47 requireContextValue(t, ctx, "first", "handler must know the first context value")
48 requireContextValue(t, ctx, "second", "handler must know the second context value")
49 return output, nil
50 }
51
52 chain := ChainUnaryServer(first, second)
53 out, _ := chain(parentContext, input, parentUnaryInfo, handler)
54 require.EqualValues(t, output, out, "chain must return handler's output")
55 }
56
57 func TestChainStreamServer(t *testing.T) {
58 someService := &struct{}{}
59 recvMessage := "received"
60 sentMessage := "sent"
61 outputError := fmt.Errorf("some error")
62
63 first := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
64 requireContextValue(t, stream.Context(), "parent", "first interceptor must know the parent context value")
65 require.Equal(t, parentStreamInfo, info, "first interceptor must know the parentStreamInfo")
66 require.Equal(t, someService, srv, "first interceptor must know someService")
67 wrapped := WrapServerStream(stream)
68 wrapped.WrappedContext = context.WithValue(stream.Context(), "first", 1)
69 return handler(srv, wrapped)
70 }
71 second := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
72 requireContextValue(t, stream.Context(), "parent", "second interceptor must know the parent context value")
73 requireContextValue(t, stream.Context(), "first", "second interceptor must know the first context value")
74 require.Equal(t, parentStreamInfo, info, "second interceptor must know the parentStreamInfo")
75 require.Equal(t, someService, srv, "second interceptor must know someService")
76 wrapped := WrapServerStream(stream)
77 wrapped.WrappedContext = context.WithValue(stream.Context(), "second", 1)
78 return handler(srv, wrapped)
79 }
80 handler := func(srv interface{}, stream grpc.ServerStream) error {
81 require.Equal(t, someService, srv, "handler must know someService")
82 requireContextValue(t, stream.Context(), "parent", "handler must know the parent context value")
83 requireContextValue(t, stream.Context(), "first", "handler must know the first context value")
84 requireContextValue(t, stream.Context(), "second", "handler must know the second context value")
85 require.NoError(t, stream.RecvMsg(recvMessage), "handler must have access to stream messages")
86 require.NoError(t, stream.SendMsg(sentMessage), "handler must be able to send stream messages")
87 return outputError
88 }
89 fakeStream := &fakeServerStream{ctx: parentContext, recvMessage: recvMessage}
90 chain := ChainStreamServer(first, second)
91 err := chain(someService, fakeStream, parentStreamInfo, handler)
92 require.Equal(t, outputError, err, "chain must return handler's error")
93 require.Equal(t, sentMessage, fakeStream.sentMessage, "handler's sent message must propagate to stream")
94 }
95
96 func TestChainUnaryClient(t *testing.T) {
97 ignoredMd := metadata.Pairs("foo", "bar")
98 parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)}
99 reqMessage := "request"
100 replyMessage := "reply"
101 outputError := fmt.Errorf("some error")
102
103 first := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
104 requireContextValue(t, ctx, "parent", "first must know the parent context value")
105 require.Equal(t, someServiceName, method, "first must know someService")
106 require.Len(t, opts, 1, "first should see parent CallOptions")
107 wrappedCtx := context.WithValue(ctx, "first", 1)
108 return invoker(wrappedCtx, method, req, reply, cc, opts...)
109 }
110 second := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
111 requireContextValue(t, ctx, "parent", "second must know the parent context value")
112 requireContextValue(t, ctx, "first", "second must know the first context value")
113 require.Equal(t, someServiceName, method, "second must know someService")
114 require.Len(t, opts, 1, "second should see parent CallOptions")
115 wrappedOpts := append(opts, grpc.WaitForReady(false))
116 wrappedCtx := context.WithValue(ctx, "second", 1)
117 return invoker(wrappedCtx, method, req, reply, cc, wrappedOpts...)
118 }
119 invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
120 require.Equal(t, someServiceName, method, "invoker must know someService")
121 requireContextValue(t, ctx, "parent", "invoker must know the parent context value")
122 requireContextValue(t, ctx, "first", "invoker must know the first context value")
123 requireContextValue(t, ctx, "second", "invoker must know the second context value")
124 require.Len(t, opts, 2, "invoker should see both CallOpts from second and parent")
125 return outputError
126 }
127 chain := ChainUnaryClient(first, second)
128 err := chain(parentContext, someServiceName, reqMessage, replyMessage, nil, invoker, parentOpts...)
129 require.Equal(t, outputError, err, "chain must return invokers's error")
130 }
131
132 func TestChainStreamClient(t *testing.T) {
133 ignoredMd := metadata.Pairs("foo", "bar")
134 parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)}
135 clientStream := &fakeClientStream{}
136 fakeStreamDesc := &grpc.StreamDesc{ClientStreams: true, ServerStreams: true, StreamName: someServiceName}
137
138 first := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
139 requireContextValue(t, ctx, "parent", "first must know the parent context value")
140 require.Equal(t, someServiceName, method, "first must know someService")
141 require.Len(t, opts, 1, "first should see parent CallOptions")
142 wrappedCtx := context.WithValue(ctx, "first", 1)
143 return streamer(wrappedCtx, desc, cc, method, opts...)
144 }
145 second := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
146 requireContextValue(t, ctx, "parent", "second must know the parent context value")
147 requireContextValue(t, ctx, "first", "second must know the first context value")
148 require.Equal(t, someServiceName, method, "second must know someService")
149 require.Len(t, opts, 1, "second should see parent CallOptions")
150 wrappedOpts := append(opts, grpc.WaitForReady(false))
151 wrappedCtx := context.WithValue(ctx, "second", 1)
152 return streamer(wrappedCtx, desc, cc, method, wrappedOpts...)
153 }
154 streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
155 require.Equal(t, someServiceName, method, "streamer must know someService")
156 require.Equal(t, fakeStreamDesc, desc, "streamer must see the right StreamDesc")
157
158 requireContextValue(t, ctx, "parent", "streamer must know the parent context value")
159 requireContextValue(t, ctx, "first", "streamer must know the first context value")
160 requireContextValue(t, ctx, "second", "streamer must know the second context value")
161 require.Len(t, opts, 2, "streamer should see both CallOpts from second and parent")
162 return clientStream, nil
163 }
164 chain := ChainStreamClient(first, second)
165 someStream, err := chain(parentContext, fakeStreamDesc, nil, someServiceName, streamer, parentOpts...)
166 require.NoError(t, err, "chain must not return an error")
167 require.Equal(t, clientStream, someStream, "chain must return invokers's clientstream")
168 }
169
170 func requireContextValue(t *testing.T, ctx context.Context, key string, msg ...interface{}) {
171 val := ctx.Value(key)
172 require.NotNil(t, val, msg...)
173 require.Equal(t, someValue, val, msg...)
174 }
175
View as plain text