...

Source file src/github.com/google/s2a-go/internal/record/ticketsender_test.go

Documentation: github.com/google/s2a-go/internal/record

     1  /*
     2   *
     3   * Copyright 2021 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package record
    20  
    21  import (
    22  	"errors"
    23  	"fmt"
    24  	"testing"
    25  
    26  	"github.com/google/go-cmp/cmp"
    27  	"github.com/google/go-cmp/cmp/cmpopts"
    28  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
    29  	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
    30  	"github.com/google/s2a-go/internal/tokenmanager"
    31  	"google.golang.org/grpc/codes"
    32  	"google.golang.org/protobuf/testing/protocmp"
    33  )
    34  
    35  const (
    36  	testAccessToken = "test_access_token"
    37  )
    38  
    39  type fakeStream struct {
    40  	// returnInvalid is a flag indicating whether the return status of Recv is
    41  	// OK or not.
    42  	returnInvalid bool
    43  	// returnRecvErr is a flag indicating whether an error should be returned by
    44  	// Recv.
    45  	returnRecvErr bool
    46  }
    47  
    48  func (fs *fakeStream) Send(req *s2apb.SessionReq) error {
    49  	if len(req.GetResumptionTicket().GetInBytes()) == 0 {
    50  		return errors.New("fakeStream Send received an empty InBytes")
    51  	}
    52  	if req.GetResumptionTicket().GetConnectionId() == 0 {
    53  		return errors.New("fakeStream Send received a 0 ConnectionId")
    54  	}
    55  	if req.GetResumptionTicket().GetLocalIdentity() == nil {
    56  		return errors.New("fakeStream Send received an empty LocalIdentity")
    57  	}
    58  	return nil
    59  }
    60  
    61  func (fs *fakeStream) Recv() (*s2apb.SessionResp, error) {
    62  	if fs.returnRecvErr {
    63  		return nil, errors.New("fakeStream Recv error")
    64  	}
    65  	if fs.returnInvalid {
    66  		return &s2apb.SessionResp{
    67  			Status: &s2apb.SessionStatus{Code: uint32(codes.InvalidArgument)},
    68  		}, nil
    69  	}
    70  	return &s2apb.SessionResp{
    71  		Status: &s2apb.SessionStatus{Code: uint32(codes.OK)},
    72  	}, nil
    73  }
    74  
    75  type fakeAccessTokenManager struct {
    76  	acceptedIdentity   *commonpb.Identity
    77  	accessToken        string
    78  	allowEmptyIdentity bool
    79  }
    80  
    81  func (m *fakeAccessTokenManager) DefaultToken() (string, error) {
    82  	if !m.allowEmptyIdentity {
    83  		return "", fmt.Errorf("not allowed to get token for empty identity")
    84  	}
    85  	return m.accessToken, nil
    86  }
    87  
    88  func (m *fakeAccessTokenManager) Token(identity *commonpb.Identity) (string, error) {
    89  	if identity == nil || cmp.Equal(identity, &commonpb.Identity{}, protocmp.Transform()) {
    90  		if !m.allowEmptyIdentity {
    91  			return "", fmt.Errorf("not allowed to get token for empty identity")
    92  		}
    93  		return m.accessToken, nil
    94  	}
    95  	if cmp.Equal(identity, m.acceptedIdentity, protocmp.Transform()) {
    96  		return m.accessToken, nil
    97  	}
    98  	return "", fmt.Errorf("unable to get token")
    99  }
   100  
   101  func TestWriteTicketsToStream(t *testing.T) {
   102  	for _, tc := range []struct {
   103  		returnInvalid   bool
   104  		returnRecvError bool
   105  	}{
   106  		{
   107  			// Both flags are set to false.
   108  		},
   109  		{
   110  			returnInvalid: true,
   111  		},
   112  		{
   113  			returnRecvError: true,
   114  		},
   115  	} {
   116  		sender := ticketSender{
   117  			connectionID: 1,
   118  			localIdentity: &commonpb.Identity{
   119  				IdentityOneof: &commonpb.Identity_SpiffeId{
   120  					SpiffeId: "test_spiffe_id",
   121  				},
   122  			},
   123  		}
   124  		fs := &fakeStream{returnInvalid: tc.returnInvalid, returnRecvErr: tc.returnRecvError}
   125  		if got, want := sender.writeTicketsToStream(fs, make([][]byte, 1)) == nil, !tc.returnRecvError && !tc.returnInvalid; got != want {
   126  			t.Errorf("sender.writeTicketsToStream(%v, _) = (err=nil) = %v, want %v", fs, got, want)
   127  		}
   128  	}
   129  }
   130  
   131  func TestGetAuthMechanism(t *testing.T) {
   132  	sortProtos := cmpopts.SortSlices(func(m1, m2 *s2apb.AuthenticationMechanism) bool { return m1.String() < m2.String() })
   133  	for _, tc := range []struct {
   134  		description            string
   135  		localIdentity          *commonpb.Identity
   136  		tokenManager           tokenmanager.AccessTokenManager
   137  		expectedAuthMechanisms []*s2apb.AuthenticationMechanism
   138  	}{
   139  		{
   140  			description:            "token manager is nil",
   141  			tokenManager:           nil,
   142  			expectedAuthMechanisms: nil,
   143  		},
   144  		{
   145  			description: "token manager expects empty identity",
   146  			tokenManager: &fakeAccessTokenManager{
   147  				accessToken:        testAccessToken,
   148  				allowEmptyIdentity: true,
   149  			},
   150  			expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
   151  				{
   152  					MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   153  						Token: testAccessToken,
   154  					},
   155  				},
   156  			},
   157  		},
   158  		{
   159  			description: "token manager does not expect empty identity",
   160  			tokenManager: &fakeAccessTokenManager{
   161  				allowEmptyIdentity: false,
   162  			},
   163  			expectedAuthMechanisms: nil,
   164  		},
   165  		{
   166  			description: "token manager expects SPIFFE ID",
   167  			localIdentity: &commonpb.Identity{
   168  				IdentityOneof: &commonpb.Identity_SpiffeId{
   169  					SpiffeId: "allowed_spiffe_id",
   170  				},
   171  			},
   172  			tokenManager: &fakeAccessTokenManager{
   173  				accessToken: testAccessToken,
   174  				acceptedIdentity: &commonpb.Identity{
   175  					IdentityOneof: &commonpb.Identity_SpiffeId{
   176  						SpiffeId: "allowed_spiffe_id",
   177  					},
   178  				},
   179  			},
   180  			expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
   181  				{
   182  					Identity: &commonpb.Identity{
   183  						IdentityOneof: &commonpb.Identity_SpiffeId{
   184  							SpiffeId: "allowed_spiffe_id",
   185  						},
   186  					},
   187  					MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   188  						Token: testAccessToken,
   189  					},
   190  				},
   191  			},
   192  		},
   193  		{
   194  			description: "token manager does not expect hostname",
   195  
   196  			localIdentity: &commonpb.Identity{
   197  				IdentityOneof: &commonpb.Identity_Hostname{
   198  					Hostname: "disallowed_hostname",
   199  				},
   200  			},
   201  			tokenManager:           &fakeAccessTokenManager{},
   202  			expectedAuthMechanisms: nil,
   203  		},
   204  	} {
   205  		t.Run(tc.description, func(t *testing.T) {
   206  			ticketSender := &ticketSender{
   207  				localIdentity: tc.localIdentity,
   208  				tokenManager:  tc.tokenManager,
   209  			}
   210  			authMechanisms := ticketSender.getAuthMechanisms()
   211  			if got, want := (authMechanisms == nil), (tc.expectedAuthMechanisms == nil); got != want {
   212  				t.Errorf("authMechanisms == nil: %t, tc.expectedAuthMechanisms == nil: %t", got, want)
   213  			}
   214  			if authMechanisms != nil && tc.expectedAuthMechanisms != nil {
   215  				if diff := cmp.Diff(authMechanisms, tc.expectedAuthMechanisms, protocmp.Transform(), sortProtos); diff != "" {
   216  					t.Errorf("ticketSender.getAuthMechanisms() returned incorrect slice, (-want +got):\n%s", diff)
   217  				}
   218  			}
   219  		})
   220  	}
   221  }
   222  

View as plain text