...

Source file src/github.com/grpc-ecosystem/go-grpc-middleware/auth/auth_test.go

Documentation: github.com/grpc-ecosystem/go-grpc-middleware/auth

     1  // Copyright 2016 Michal Witkowski. All Rights Reserved.
     2  // See LICENSE for licensing terms.
     3  
     4  package grpc_auth_test
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/grpc-ecosystem/go-grpc-middleware/auth"
    13  	"github.com/grpc-ecosystem/go-grpc-middleware/testing"
    14  	pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
    15  	"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/require"
    18  	"github.com/stretchr/testify/suite"
    19  	"golang.org/x/oauth2"
    20  	"google.golang.org/grpc"
    21  	"google.golang.org/grpc/codes"
    22  	"google.golang.org/grpc/credentials/oauth"
    23  	"google.golang.org/grpc/metadata"
    24  	"google.golang.org/grpc/status"
    25  )
    26  
    27  var (
    28  	commonAuthToken   = "some_good_token"
    29  	overrideAuthToken = "override_token"
    30  
    31  	authedMarker = "some_context_marker"
    32  	goodPing     = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
    33  )
    34  
    35  // TODO(mwitkow): Add auth from metadata client dialer, which requires TLS.
    36  
    37  func buildDummyAuthFunction(expectedScheme string, expectedToken string) func(ctx context.Context) (context.Context, error) {
    38  	return func(ctx context.Context) (context.Context, error) {
    39  		token, err := grpc_auth.AuthFromMD(ctx, expectedScheme)
    40  		if err != nil {
    41  			return nil, err
    42  		}
    43  		if token != expectedToken {
    44  			return nil, status.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token")
    45  		}
    46  		return context.WithValue(ctx, authedMarker, "marker_exists"), nil
    47  	}
    48  }
    49  
    50  func assertAuthMarkerExists(t *testing.T, ctx context.Context) {
    51  	assert.Equal(t, "marker_exists", ctx.Value(authedMarker).(string), "auth marker from buildDummyAuthFunction must be passed around")
    52  }
    53  
// assertingPingService wraps a TestServiceServer and, in the methods it
// overrides (PingError and PingList), asserts that the auth marker is present
// in the request context before delegating to the embedded server.
type assertingPingService struct {
	// Embedded server that handles every call after the assertion.
	pb_testproto.TestServiceServer
	// T is the test that marker assertions are reported against.
	T *testing.T
}
    58  
    59  func (s *assertingPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) {
    60  	assertAuthMarkerExists(s.T, ctx)
    61  	return s.TestServiceServer.PingError(ctx, ping)
    62  }
    63  
    64  func (s *assertingPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
    65  	assertAuthMarkerExists(s.T, stream.Context())
    66  	return s.TestServiceServer.PingList(ping, stream)
    67  }
    68  
    69  func ctxWithToken(ctx context.Context, scheme string, token string) context.Context {
    70  	md := metadata.Pairs("authorization", fmt.Sprintf("%s %v", scheme, token))
    71  	nCtx := metautils.NiceMD(md).ToOutgoing(ctx)
    72  	return nCtx
    73  }
    74  
    75  func TestAuthTestSuite(t *testing.T) {
    76  	authFunc := buildDummyAuthFunction("bearer", commonAuthToken)
    77  	s := &AuthTestSuite{
    78  		InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
    79  			TestService: &assertingPingService{&grpc_testing.TestPingService{T: t}, t},
    80  			ServerOpts: []grpc.ServerOption{
    81  				grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)),
    82  				grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)),
    83  			},
    84  		},
    85  	}
    86  	suite.Run(t, s)
    87  }
    88  
// AuthTestSuite holds the tests run by TestAuthTestSuite, i.e. against a
// server whose interceptors use the auth function built for commonAuthToken.
type AuthTestSuite struct {
	// Embedded suite that supplies Client, NewClient, SimpleCtx and T to the
	// test methods.
	*grpc_testing.InterceptorTestSuite
}
    92  
    93  func (s *AuthTestSuite) TestUnary_NoAuth() {
    94  	_, err := s.Client.Ping(s.SimpleCtx(), goodPing)
    95  	assert.Error(s.T(), err, "there must be an error")
    96  	assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated")
    97  }
    98  
    99  func (s *AuthTestSuite) TestUnary_BadAuth() {
   100  	_, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing)
   101  	assert.Error(s.T(), err, "there must be an error")
   102  	assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied")
   103  }
   104  
   105  func (s *AuthTestSuite) TestUnary_PassesAuth() {
   106  	_, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", commonAuthToken), goodPing)
   107  	require.NoError(s.T(), err, "no error must occur")
   108  }
   109  
   110  func (s *AuthTestSuite) TestUnary_PassesWithPerRpcCredentials() {
   111  	grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}}
   112  	client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds))
   113  	_, err := client.Ping(s.SimpleCtx(), goodPing)
   114  	require.NoError(s.T(), err, "no error must occur")
   115  }
   116  
   117  func (s *AuthTestSuite) TestStream_NoAuth() {
   118  	stream, err := s.Client.PingList(s.SimpleCtx(), goodPing)
   119  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   120  	_, err = stream.Recv()
   121  	assert.Error(s.T(), err, "there must be an error")
   122  	assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated")
   123  }
   124  
   125  func (s *AuthTestSuite) TestStream_BadAuth() {
   126  	stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing)
   127  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   128  	_, err = stream.Recv()
   129  	assert.Error(s.T(), err, "there must be an error")
   130  	assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied")
   131  }
   132  
   133  func (s *AuthTestSuite) TestStream_PassesAuth() {
   134  	stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", commonAuthToken), goodPing)
   135  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   136  	pong, err := stream.Recv()
   137  	require.NoError(s.T(), err, "no error must occur")
   138  	require.NotNil(s.T(), pong, "pong must not be nil")
   139  }
   140  
   141  func (s *AuthTestSuite) TestStream_PassesWithPerRpcCredentials() {
   142  	grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}}
   143  	client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds))
   144  	stream, err := client.PingList(s.SimpleCtx(), goodPing)
   145  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   146  	pong, err := stream.Recv()
   147  	require.NoError(s.T(), err, "no error must occur")
   148  	require.NotNil(s.T(), pong, "pong must not be nil")
   149  }
   150  
// authOverrideTestService wraps a TestServiceServer and adds an
// AuthFuncOverride method that authenticates against overrideAuthToken.
// NOTE(review): presumably the grpc_auth interceptors call AuthFuncOverride in
// place of the AuthFunc they were constructed with — confirm against the
// grpc_auth package.
type authOverrideTestService struct {
	// Embedded server that handles the actual RPCs.
	pb_testproto.TestServiceServer
	// T is the test that assertions in AuthFuncOverride are reported against.
	T *testing.T
}
   155  
   156  func (s *authOverrideTestService) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) {
   157  	assert.NotEmpty(s.T, fullMethodName, "method name of caller is passed around")
   158  	return buildDummyAuthFunction("bearer", overrideAuthToken)(ctx)
   159  }
   160  
   161  func TestAuthOverrideTestSuite(t *testing.T) {
   162  	authFunc := buildDummyAuthFunction("bearer", commonAuthToken)
   163  	s := &AuthOverrideTestSuite{
   164  		InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
   165  			TestService: &authOverrideTestService{&assertingPingService{&grpc_testing.TestPingService{T: t}, t}, t},
   166  			ServerOpts: []grpc.ServerOption{
   167  				grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)),
   168  				grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)),
   169  			},
   170  		},
   171  	}
   172  	suite.Run(t, s)
   173  }
   174  
// AuthOverrideTestSuite holds the tests run by TestAuthOverrideTestSuite,
// i.e. against a service that provides its own AuthFuncOverride.
type AuthOverrideTestSuite struct {
	// Embedded suite that supplies Client, SimpleCtx and T to the test
	// methods.
	*grpc_testing.InterceptorTestSuite
}
   178  
   179  func (s *AuthOverrideTestSuite) TestUnary_PassesAuth() {
   180  	_, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", overrideAuthToken), goodPing)
   181  	require.NoError(s.T(), err, "no error must occur")
   182  }
   183  
   184  func (s *AuthOverrideTestSuite) TestStream_PassesAuth() {
   185  	stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", overrideAuthToken), goodPing)
   186  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   187  	pong, err := stream.Recv()
   188  	require.NoError(s.T(), err, "no error must occur")
   189  	require.NotNil(s.T(), pong, "pong must not be nil")
   190  }
   191  
// fakeOAuth2TokenSource implements a fake oauth2.TokenSource for the purpose of credentials test.
// Its Token method always succeeds and returns a token built from accessToken.
type fakeOAuth2TokenSource struct {
	// accessToken is the value returned as the AccessToken of every token.
	accessToken string
}
   196  
   197  func (ts *fakeOAuth2TokenSource) Token() (*oauth2.Token, error) {
   198  	t := &oauth2.Token{
   199  		AccessToken: ts.accessToken,
   200  		Expiry:      time.Now().Add(1 * time.Minute),
   201  		TokenType:   "bearer",
   202  	}
   203  	return t, nil
   204  }
   205  

View as plain text