...

Source file src/github.com/google/s2a-go/internal/fakehandshaker/service/s2a_service_test.go

Documentation: github.com/google/s2a-go/internal/fakehandshaker/service

     1  /*
     2   *
     3   * Copyright 2021 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package service
    20  
    21  import (
    22  	"errors"
    23  	"os"
    24  	"strings"
    25  	"testing"
    26  
    27  	"github.com/google/go-cmp/cmp"
    28  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
    29  	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
    30  	"google.golang.org/grpc"
    31  	"google.golang.org/grpc/codes"
    32  	"google.golang.org/protobuf/testing/protocmp"
    33  )
    34  
    35  const (
    36  	testAccessToken = "test_access_token"
    37  )
    38  
    39  type fakeS2ASetupSessionServer struct {
    40  	grpc.ServerStream
    41  	recvCount int
    42  	reqs      []*s2apb.SessionReq
    43  	resps     []*s2apb.SessionResp
    44  }
    45  
    46  func (f *fakeS2ASetupSessionServer) Send(resp *s2apb.SessionResp) error {
    47  	f.resps = append(f.resps, resp)
    48  	return nil
    49  }
    50  
    51  func (f *fakeS2ASetupSessionServer) Recv() (*s2apb.SessionReq, error) {
    52  	if f.recvCount == len(f.reqs) {
    53  		return nil, errors.New("request buffer was fully exhausted")
    54  	}
    55  	req := f.reqs[f.recvCount]
    56  	f.recvCount++
    57  	return req, nil
    58  }
    59  
    60  func TestSetupSession(t *testing.T) {
    61  	os.Setenv(accessTokenEnvVariable, "")
    62  	for _, tc := range []struct {
    63  		desc string
    64  		// Note that outResps[i] is the output for reqs[i].
    65  		reqs           []*s2apb.SessionReq
    66  		outResps       []*s2apb.SessionResp
    67  		hasNonOKStatus bool
    68  	}{
    69  		{
    70  			desc: "client failure no app protocols",
    71  			reqs: []*s2apb.SessionReq{
    72  				{
    73  					ReqOneof: &s2apb.SessionReq_ClientStart{
    74  						ClientStart: &s2apb.ClientSessionStartReq{},
    75  					},
    76  				},
    77  			},
    78  			hasNonOKStatus: true,
    79  		},
    80  		{
    81  			desc: "client failure non initial state",
    82  			reqs: []*s2apb.SessionReq{
    83  				{
    84  					ReqOneof: &s2apb.SessionReq_ClientStart{
    85  						ClientStart: &s2apb.ClientSessionStartReq{
    86  							ApplicationProtocols: []string{grpcAppProtocol},
    87  							MinTlsVersion:        commonpb.TLSVersion_TLS1_3,
    88  							MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
    89  							TlsCiphersuites: []commonpb.Ciphersuite{
    90  								commonpb.Ciphersuite_AES_128_GCM_SHA256,
    91  								commonpb.Ciphersuite_AES_256_GCM_SHA384,
    92  								commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
    93  							},
    94  						},
    95  					},
    96  				},
    97  				{
    98  					ReqOneof: &s2apb.SessionReq_ClientStart{
    99  						ClientStart: &s2apb.ClientSessionStartReq{
   100  							ApplicationProtocols: []string{grpcAppProtocol},
   101  							MinTlsVersion:        commonpb.TLSVersion_TLS1_3,
   102  							MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
   103  							TlsCiphersuites: []commonpb.Ciphersuite{
   104  								commonpb.Ciphersuite_AES_128_GCM_SHA256,
   105  								commonpb.Ciphersuite_AES_256_GCM_SHA384,
   106  								commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
   107  							},
   108  						},
   109  					},
   110  				},
   111  			},
   112  			outResps: []*s2apb.SessionResp{
   113  				{
   114  					OutFrames: []byte(clientHelloFrame),
   115  					Status: &s2apb.SessionStatus{
   116  						Code: uint32(codes.OK),
   117  					},
   118  				},
   119  			},
   120  			hasNonOKStatus: true,
   121  		},
   122  		{
   123  			desc: "client test",
   124  			reqs: []*s2apb.SessionReq{
   125  				{
   126  					ReqOneof: &s2apb.SessionReq_ClientStart{
   127  						ClientStart: &s2apb.ClientSessionStartReq{
   128  							ApplicationProtocols: []string{grpcAppProtocol},
   129  							MinTlsVersion:        commonpb.TLSVersion_TLS1_3,
   130  							MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
   131  							TlsCiphersuites: []commonpb.Ciphersuite{
   132  								commonpb.Ciphersuite_AES_128_GCM_SHA256,
   133  								commonpb.Ciphersuite_AES_256_GCM_SHA384,
   134  								commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
   135  							},
   136  							LocalIdentity: &commonpb.Identity{
   137  								IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
   138  							},
   139  							TargetIdentities: []*commonpb.Identity{
   140  								{
   141  									IdentityOneof: &commonpb.Identity_SpiffeId{SpiffeId: "peer spiffe identity"},
   142  								},
   143  							},
   144  						},
   145  					},
   146  				},
   147  				{
   148  					ReqOneof: &s2apb.SessionReq_Next{
   149  						Next: &s2apb.SessionNextReq{
   150  							InBytes: []byte(serverFrame),
   151  						},
   152  					},
   153  				},
   154  			},
   155  			outResps: []*s2apb.SessionResp{
   156  				{
   157  					OutFrames: []byte(clientHelloFrame),
   158  					Status: &s2apb.SessionStatus{
   159  						Code: uint32(codes.OK),
   160  					},
   161  				},
   162  				{
   163  					OutFrames:     []byte(clientFinishedFrame),
   164  					BytesConsumed: uint32(len(serverFrame)),
   165  					Result: &s2apb.SessionResult{
   166  						ApplicationProtocol: grpcAppProtocol,
   167  						State: &s2apb.SessionState{
   168  							TlsVersion:     commonpb.TLSVersion_TLS1_3,
   169  							TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
   170  							InKey:          []byte(inKey),
   171  							OutKey:         []byte(outKey),
   172  						},
   173  						PeerIdentity: &commonpb.Identity{
   174  							IdentityOneof: &commonpb.Identity_SpiffeId{SpiffeId: "peer spiffe identity"},
   175  						},
   176  						LocalIdentity: &commonpb.Identity{
   177  							IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
   178  						},
   179  					},
   180  					Status: &s2apb.SessionStatus{
   181  						Code: uint32(codes.OK),
   182  					},
   183  				},
   184  			},
   185  		},
   186  		{
   187  			desc: "server failure no app protocols",
   188  			reqs: []*s2apb.SessionReq{
   189  				{
   190  					ReqOneof: &s2apb.SessionReq_ServerStart{
   191  						ServerStart: &s2apb.ServerSessionStartReq{},
   192  					},
   193  				},
   194  			},
   195  			hasNonOKStatus: true,
   196  		},
   197  		{
   198  			desc: "server failure non initial state",
   199  			reqs: []*s2apb.SessionReq{
   200  				{
   201  					ReqOneof: &s2apb.SessionReq_ServerStart{
   202  						ServerStart: &s2apb.ServerSessionStartReq{
   203  							ApplicationProtocols: []string{grpcAppProtocol},
   204  							MinTlsVersion:        commonpb.TLSVersion_TLS1_3,
   205  							MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
   206  							TlsCiphersuites: []commonpb.Ciphersuite{
   207  								commonpb.Ciphersuite_AES_128_GCM_SHA256,
   208  								commonpb.Ciphersuite_AES_256_GCM_SHA384,
   209  								commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
   210  							},
   211  						},
   212  					},
   213  				},
   214  				{
   215  					ReqOneof: &s2apb.SessionReq_ServerStart{
   216  						ServerStart: &s2apb.ServerSessionStartReq{
   217  							ApplicationProtocols: []string{grpcAppProtocol},
   218  							MinTlsVersion:        commonpb.TLSVersion_TLS1_3,
   219  							MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
   220  							TlsCiphersuites: []commonpb.Ciphersuite{
   221  								commonpb.Ciphersuite_AES_128_GCM_SHA256,
   222  								commonpb.Ciphersuite_AES_256_GCM_SHA384,
   223  								commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
   224  							},
   225  						},
   226  					},
   227  				},
   228  			},
   229  			outResps: []*s2apb.SessionResp{
   230  				{
   231  					Status: &s2apb.SessionStatus{
   232  						Code: uint32(codes.OK),
   233  					},
   234  				},
   235  			},
   236  			hasNonOKStatus: true,
   237  		},
   238  		{
   239  			desc: "server test",
   240  			reqs: []*s2apb.SessionReq{
   241  				{
   242  					ReqOneof: &s2apb.SessionReq_ServerStart{
   243  						ServerStart: &s2apb.ServerSessionStartReq{
   244  							ApplicationProtocols: []string{grpcAppProtocol},
   245  							MinTlsVersion:        commonpb.TLSVersion_TLS1_3,
   246  							MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
   247  							TlsCiphersuites: []commonpb.Ciphersuite{
   248  								commonpb.Ciphersuite_AES_128_GCM_SHA256,
   249  								commonpb.Ciphersuite_AES_256_GCM_SHA384,
   250  								commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
   251  							},
   252  							InBytes: []byte(clientHelloFrame),
   253  							LocalIdentities: []*commonpb.Identity{
   254  								{
   255  									IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
   256  								},
   257  							},
   258  						},
   259  					},
   260  				},
   261  				{
   262  					ReqOneof: &s2apb.SessionReq_Next{
   263  						Next: &s2apb.SessionNextReq{
   264  							InBytes: []byte(clientFinishedFrame),
   265  						},
   266  					},
   267  				},
   268  			},
   269  			outResps: []*s2apb.SessionResp{
   270  				{
   271  					OutFrames:     []byte(serverFrame),
   272  					BytesConsumed: uint32(len(clientHelloFrame)),
   273  					Status: &s2apb.SessionStatus{
   274  						Code: uint32(codes.OK),
   275  					},
   276  				},
   277  				{
   278  					BytesConsumed: uint32(len(clientFinishedFrame)),
   279  					Result: &s2apb.SessionResult{
   280  						ApplicationProtocol: grpcAppProtocol,
   281  						State: &s2apb.SessionState{
   282  							TlsVersion:     commonpb.TLSVersion_TLS1_3,
   283  							TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
   284  							InKey:          []byte(inKey),
   285  							OutKey:         []byte(outKey),
   286  						},
   287  						LocalIdentity: &commonpb.Identity{
   288  							IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
   289  						},
   290  					},
   291  					Status: &s2apb.SessionStatus{
   292  						Code: uint32(codes.OK),
   293  					},
   294  				},
   295  			},
   296  		},
   297  		{
   298  			desc: "resumption ticket test",
   299  			reqs: []*s2apb.SessionReq{
   300  				{
   301  					ReqOneof: &s2apb.SessionReq_ResumptionTicket{
   302  						ResumptionTicket: &s2apb.ResumptionTicketReq{
   303  							ConnectionId: 1234,
   304  							LocalIdentity: &commonpb.Identity{
   305  								IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
   306  							},
   307  						},
   308  					},
   309  				},
   310  			},
   311  			outResps: []*s2apb.SessionResp{
   312  				{
   313  					Status: &s2apb.SessionStatus{
   314  						Code: uint32(codes.OK),
   315  					},
   316  				},
   317  			},
   318  			hasNonOKStatus: false,
   319  		},
   320  	} {
   321  		t.Run(tc.desc, func(t *testing.T) {
   322  			hs := FakeHandshakerService{}
   323  			stream := &fakeS2ASetupSessionServer{reqs: tc.reqs}
   324  			if got, want := hs.SetUpSession(stream) == nil, !tc.hasNonOKStatus; got != want {
   325  				t.Errorf("hs.SetUpSession(%v) = (err=nil) = %v, want %v", stream, got, want)
   326  			}
   327  			hasNonOKStatus := false
   328  			for i := range tc.reqs {
   329  				if stream.resps[i].GetStatus().GetCode() != uint32(codes.OK) {
   330  					hasNonOKStatus = true
   331  					break
   332  				}
   333  				if got, want := stream.resps[i], tc.outResps[i]; !cmp.Equal(got, want, protocmp.Transform()) {
   334  					t.Fatalf("stream.resps[%d] = %v, want %v", i, got, want)
   335  				}
   336  			}
   337  			if got, want := hasNonOKStatus, tc.hasNonOKStatus; got != want {
   338  				t.Errorf("hasNonOKStatus = %v, want %v", got, want)
   339  			}
   340  		})
   341  	}
   342  }
   343  
   344  func TestAuthenticateRequest(t *testing.T) {
   345  	for _, tc := range []struct {
   346  		description   string
   347  		acceptedToken string
   348  		request       *s2apb.SessionReq
   349  		expectedError string
   350  	}{
   351  		{
   352  			description: "access token env variable is not set",
   353  		},
   354  		{
   355  			description:   "request contains valid token",
   356  			acceptedToken: testAccessToken,
   357  			request: &s2apb.SessionReq{
   358  				AuthMechanisms: []*s2apb.AuthenticationMechanism{
   359  					{
   360  						MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   361  							Token: testAccessToken,
   362  						},
   363  					},
   364  				},
   365  			},
   366  		},
   367  		{
   368  			description:   "request contains invalid token",
   369  			acceptedToken: testAccessToken,
   370  			request: &s2apb.SessionReq{
   371  				AuthMechanisms: []*s2apb.AuthenticationMechanism{
   372  					{
   373  						MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   374  							Token: "bad_access_token",
   375  						},
   376  					},
   377  				},
   378  			},
   379  			expectedError: "received token: bad_access_token, expected token: test_access_token",
   380  		},
   381  		{
   382  			description:   "request contains valid and invalid tokens",
   383  			acceptedToken: testAccessToken,
   384  			request: &s2apb.SessionReq{
   385  				AuthMechanisms: []*s2apb.AuthenticationMechanism{
   386  					{
   387  						MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   388  							Token: testAccessToken,
   389  						},
   390  					},
   391  					{
   392  						MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   393  							Token: "bad_access_token",
   394  						},
   395  					},
   396  				},
   397  			},
   398  			expectedError: "received token: bad_access_token, expected token: test_access_token",
   399  		},
   400  	} {
   401  		t.Run(tc.description, func(t *testing.T) {
   402  			os.Setenv(accessTokenEnvVariable, tc.acceptedToken)
   403  			hs := &FakeHandshakerService{}
   404  			err := hs.authenticateRequest(tc.request)
   405  			if got, want := (err == nil), (tc.expectedError == ""); got != want {
   406  				t.Errorf("(err == nil): %t, (tc.expectedError == \"\"): %t", got, want)
   407  			}
   408  			if err != nil && !strings.Contains(err.Error(), tc.expectedError) {
   409  				t.Errorf("hs.authenticateRequest(%v)=%v, expected error to have substring: %v", tc.request, err, tc.expectedError)
   410  			}
   411  		})
   412  	}
   413  }
   414  

View as plain text