1
2
3
4 package grpc_auth_test
5
6 import (
7 "context"
8 "fmt"
9 "testing"
10 "time"
11
12 "github.com/grpc-ecosystem/go-grpc-middleware/auth"
13 "github.com/grpc-ecosystem/go-grpc-middleware/testing"
14 pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
15 "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
16 "github.com/stretchr/testify/assert"
17 "github.com/stretchr/testify/require"
18 "github.com/stretchr/testify/suite"
19 "golang.org/x/oauth2"
20 "google.golang.org/grpc"
21 "google.golang.org/grpc/codes"
22 "google.golang.org/grpc/credentials/oauth"
23 "google.golang.org/grpc/metadata"
24 "google.golang.org/grpc/status"
25 )
26
27 var (
28 commonAuthToken = "some_good_token"
29 overrideAuthToken = "override_token"
30
31 authedMarker = "some_context_marker"
32 goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
33 )
34
35
36
37 func buildDummyAuthFunction(expectedScheme string, expectedToken string) func(ctx context.Context) (context.Context, error) {
38 return func(ctx context.Context) (context.Context, error) {
39 token, err := grpc_auth.AuthFromMD(ctx, expectedScheme)
40 if err != nil {
41 return nil, err
42 }
43 if token != expectedToken {
44 return nil, status.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token")
45 }
46 return context.WithValue(ctx, authedMarker, "marker_exists"), nil
47 }
48 }
49
50 func assertAuthMarkerExists(t *testing.T, ctx context.Context) {
51 assert.Equal(t, "marker_exists", ctx.Value(authedMarker).(string), "auth marker from buildDummyAuthFunction must be passed around")
52 }
53
54 type assertingPingService struct {
55 pb_testproto.TestServiceServer
56 T *testing.T
57 }
58
59 func (s *assertingPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) {
60 assertAuthMarkerExists(s.T, ctx)
61 return s.TestServiceServer.PingError(ctx, ping)
62 }
63
64 func (s *assertingPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
65 assertAuthMarkerExists(s.T, stream.Context())
66 return s.TestServiceServer.PingList(ping, stream)
67 }
68
69 func ctxWithToken(ctx context.Context, scheme string, token string) context.Context {
70 md := metadata.Pairs("authorization", fmt.Sprintf("%s %v", scheme, token))
71 nCtx := metautils.NiceMD(md).ToOutgoing(ctx)
72 return nCtx
73 }
74
75 func TestAuthTestSuite(t *testing.T) {
76 authFunc := buildDummyAuthFunction("bearer", commonAuthToken)
77 s := &AuthTestSuite{
78 InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
79 TestService: &assertingPingService{&grpc_testing.TestPingService{T: t}, t},
80 ServerOpts: []grpc.ServerOption{
81 grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)),
82 grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)),
83 },
84 },
85 }
86 suite.Run(t, s)
87 }
88
89 type AuthTestSuite struct {
90 *grpc_testing.InterceptorTestSuite
91 }
92
93 func (s *AuthTestSuite) TestUnary_NoAuth() {
94 _, err := s.Client.Ping(s.SimpleCtx(), goodPing)
95 assert.Error(s.T(), err, "there must be an error")
96 assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated")
97 }
98
99 func (s *AuthTestSuite) TestUnary_BadAuth() {
100 _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing)
101 assert.Error(s.T(), err, "there must be an error")
102 assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied")
103 }
104
105 func (s *AuthTestSuite) TestUnary_PassesAuth() {
106 _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", commonAuthToken), goodPing)
107 require.NoError(s.T(), err, "no error must occur")
108 }
109
110 func (s *AuthTestSuite) TestUnary_PassesWithPerRpcCredentials() {
111 grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}}
112 client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds))
113 _, err := client.Ping(s.SimpleCtx(), goodPing)
114 require.NoError(s.T(), err, "no error must occur")
115 }
116
117 func (s *AuthTestSuite) TestStream_NoAuth() {
118 stream, err := s.Client.PingList(s.SimpleCtx(), goodPing)
119 require.NoError(s.T(), err, "should not fail on establishing the stream")
120 _, err = stream.Recv()
121 assert.Error(s.T(), err, "there must be an error")
122 assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated")
123 }
124
125 func (s *AuthTestSuite) TestStream_BadAuth() {
126 stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing)
127 require.NoError(s.T(), err, "should not fail on establishing the stream")
128 _, err = stream.Recv()
129 assert.Error(s.T(), err, "there must be an error")
130 assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied")
131 }
132
133 func (s *AuthTestSuite) TestStream_PassesAuth() {
134 stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", commonAuthToken), goodPing)
135 require.NoError(s.T(), err, "should not fail on establishing the stream")
136 pong, err := stream.Recv()
137 require.NoError(s.T(), err, "no error must occur")
138 require.NotNil(s.T(), pong, "pong must not be nil")
139 }
140
141 func (s *AuthTestSuite) TestStream_PassesWithPerRpcCredentials() {
142 grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}}
143 client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds))
144 stream, err := client.PingList(s.SimpleCtx(), goodPing)
145 require.NoError(s.T(), err, "should not fail on establishing the stream")
146 pong, err := stream.Recv()
147 require.NoError(s.T(), err, "no error must occur")
148 require.NotNil(s.T(), pong, "pong must not be nil")
149 }
150
151 type authOverrideTestService struct {
152 pb_testproto.TestServiceServer
153 T *testing.T
154 }
155
156 func (s *authOverrideTestService) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) {
157 assert.NotEmpty(s.T, fullMethodName, "method name of caller is passed around")
158 return buildDummyAuthFunction("bearer", overrideAuthToken)(ctx)
159 }
160
161 func TestAuthOverrideTestSuite(t *testing.T) {
162 authFunc := buildDummyAuthFunction("bearer", commonAuthToken)
163 s := &AuthOverrideTestSuite{
164 InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
165 TestService: &authOverrideTestService{&assertingPingService{&grpc_testing.TestPingService{T: t}, t}, t},
166 ServerOpts: []grpc.ServerOption{
167 grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)),
168 grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)),
169 },
170 },
171 }
172 suite.Run(t, s)
173 }
174
175 type AuthOverrideTestSuite struct {
176 *grpc_testing.InterceptorTestSuite
177 }
178
179 func (s *AuthOverrideTestSuite) TestUnary_PassesAuth() {
180 _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", overrideAuthToken), goodPing)
181 require.NoError(s.T(), err, "no error must occur")
182 }
183
184 func (s *AuthOverrideTestSuite) TestStream_PassesAuth() {
185 stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", overrideAuthToken), goodPing)
186 require.NoError(s.T(), err, "should not fail on establishing the stream")
187 pong, err := stream.Recv()
188 require.NoError(s.T(), err, "no error must occur")
189 require.NotNil(s.T(), pong, "pong must not be nil")
190 }
191
192
193 type fakeOAuth2TokenSource struct {
194 accessToken string
195 }
196
197 func (ts *fakeOAuth2TokenSource) Token() (*oauth2.Token, error) {
198 t := &oauth2.Token{
199 AccessToken: ts.accessToken,
200 Expiry: time.Now().Add(1 * time.Minute),
201 TokenType: "bearer",
202 }
203 return t, nil
204 }
205
View as plain text