...

Source file src/google.golang.org/grpc/credentials/credentials_test.go

Documentation: google.golang.org/grpc/credentials

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package credentials
    20  
    21  import (
    22  	"context"
    23  	"crypto/tls"
    24  	"net"
    25  	"strings"
    26  	"testing"
    27  	"time"
    28  
    29  	"google.golang.org/grpc/internal/grpctest"
    30  	"google.golang.org/grpc/testdata"
    31  )
    32  
    33  const defaultTestTimeout = 10 * time.Second
    34  
    35  type s struct {
    36  	grpctest.Tester
    37  }
    38  
    39  func Test(t *testing.T) {
    40  	grpctest.RunSubTests(t, s{})
    41  }
    42  
    43  // A struct that implements AuthInfo interface but does not implement GetCommonAuthInfo() method.
    44  type testAuthInfoNoGetCommonAuthInfoMethod struct{}
    45  
    46  func (ta testAuthInfoNoGetCommonAuthInfoMethod) AuthType() string {
    47  	return "testAuthInfoNoGetCommonAuthInfoMethod"
    48  }
    49  
    50  // A struct that implements AuthInfo interface and implements CommonAuthInfo() method.
    51  type testAuthInfo struct {
    52  	CommonAuthInfo
    53  }
    54  
    55  func (ta testAuthInfo) AuthType() string {
    56  	return "testAuthInfo"
    57  }
    58  
    59  func (s) TestCheckSecurityLevel(t *testing.T) {
    60  	testCases := []struct {
    61  		authLevel SecurityLevel
    62  		testLevel SecurityLevel
    63  		want      bool
    64  	}{
    65  		{
    66  			authLevel: PrivacyAndIntegrity,
    67  			testLevel: PrivacyAndIntegrity,
    68  			want:      true,
    69  		},
    70  		{
    71  			authLevel: IntegrityOnly,
    72  			testLevel: PrivacyAndIntegrity,
    73  			want:      false,
    74  		},
    75  		{
    76  			authLevel: IntegrityOnly,
    77  			testLevel: NoSecurity,
    78  			want:      true,
    79  		},
    80  		{
    81  			authLevel: InvalidSecurityLevel,
    82  			testLevel: IntegrityOnly,
    83  			want:      true,
    84  		},
    85  		{
    86  			authLevel: InvalidSecurityLevel,
    87  			testLevel: PrivacyAndIntegrity,
    88  			want:      true,
    89  		},
    90  	}
    91  	for _, tc := range testCases {
    92  		err := CheckSecurityLevel(testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: tc.authLevel}}, tc.testLevel)
    93  		if tc.want && (err != nil) {
    94  			t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String())
    95  		} else if !tc.want && (err == nil) {
    96  			t.Fatalf("CheckSeurityLevel(%s, %s) returned success but want failure", tc.authLevel.String(), tc.testLevel.String())
    97  
    98  		}
    99  	}
   100  }
   101  
   102  func (s) TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) {
   103  	if err := CheckSecurityLevel(testAuthInfoNoGetCommonAuthInfoMethod{}, PrivacyAndIntegrity); err != nil {
   104  		t.Fatalf("CheckSeurityLevel() returned failure but want success")
   105  	}
   106  }
   107  
   108  func (s) TestTLSOverrideServerName(t *testing.T) {
   109  	expectedServerName := "server.name"
   110  	c := NewTLS(nil)
   111  	c.OverrideServerName(expectedServerName)
   112  	if c.Info().ServerName != expectedServerName {
   113  		t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
   114  	}
   115  }
   116  
   117  func (s) TestTLSClone(t *testing.T) {
   118  	expectedServerName := "server.name"
   119  	c := NewTLS(nil)
   120  	c.OverrideServerName(expectedServerName)
   121  	cc := c.Clone()
   122  	if cc.Info().ServerName != expectedServerName {
   123  		t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
   124  	}
   125  	cc.OverrideServerName("")
   126  	if c.Info().ServerName != expectedServerName {
   127  		t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
   128  	}
   129  
   130  }
   131  
   132  type serverHandshake func(net.Conn) (AuthInfo, error)
   133  
   134  func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) {
   135  	tcs := []struct {
   136  		name    string
   137  		address string
   138  	}{
   139  		{
   140  			name:    "localhost",
   141  			address: "localhost:0",
   142  		},
   143  		{
   144  			name:    "ipv4",
   145  			address: "127.0.0.1:0",
   146  		},
   147  		{
   148  			name:    "ipv6",
   149  			address: "[::1]:0",
   150  		},
   151  	}
   152  
   153  	for _, tc := range tcs {
   154  		t.Run(tc.name, func(t *testing.T) {
   155  			done := make(chan AuthInfo, 1)
   156  			lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address)
   157  			defer lis.Close()
   158  			lisAddr := lis.Addr().String()
   159  			clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
   160  			// wait until server sends serverAuthInfo or fails.
   161  			serverAuthInfo, ok := <-done
   162  			if !ok {
   163  				t.Fatalf("Error at server-side")
   164  			}
   165  			if !compare(clientAuthInfo, serverAuthInfo) {
   166  				t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
   167  			}
   168  		})
   169  	}
   170  }
   171  
   172  func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) {
   173  	done := make(chan AuthInfo, 1)
   174  	lis := launchServer(t, gRPCServerHandshake, done)
   175  	defer lis.Close()
   176  	clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())
   177  	// wait until server sends serverAuthInfo or fails.
   178  	serverAuthInfo, ok := <-done
   179  	if !ok {
   180  		t.Fatalf("Error at server-side")
   181  	}
   182  	if !compare(clientAuthInfo, serverAuthInfo) {
   183  		t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)
   184  	}
   185  }
   186  
   187  func (s) TestServerAndClientHandshake(t *testing.T) {
   188  	done := make(chan AuthInfo, 1)
   189  	lis := launchServer(t, gRPCServerHandshake, done)
   190  	defer lis.Close()
   191  	clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())
   192  	// wait until server sends serverAuthInfo or fails.
   193  	serverAuthInfo, ok := <-done
   194  	if !ok {
   195  		t.Fatalf("Error at server-side")
   196  	}
   197  	if !compare(clientAuthInfo, serverAuthInfo) {
   198  		t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)
   199  	}
   200  }
   201  
   202  func compare(a1, a2 AuthInfo) bool {
   203  	if a1.AuthType() != a2.AuthType() {
   204  		return false
   205  	}
   206  	switch a1.AuthType() {
   207  	case "tls":
   208  		state1 := a1.(TLSInfo).State
   209  		state2 := a2.(TLSInfo).State
   210  		if state1.Version == state2.Version &&
   211  			state1.HandshakeComplete == state2.HandshakeComplete &&
   212  			state1.CipherSuite == state2.CipherSuite &&
   213  			state1.NegotiatedProtocol == state2.NegotiatedProtocol {
   214  			return true
   215  		}
   216  		return false
   217  	default:
   218  		return false
   219  	}
   220  }
   221  
   222  func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener {
   223  	return launchServerOnListenAddress(t, hs, done, "localhost:0")
   224  }
   225  
   226  func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan AuthInfo, address string) net.Listener {
   227  	lis, err := net.Listen("tcp", address)
   228  	if err != nil {
   229  		if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
   230  			strings.Contains(err.Error(), "socket: address family not supported by protocol") {
   231  			t.Skipf("no support for address %v", address)
   232  		}
   233  		t.Fatalf("Failed to listen: %v", err)
   234  	}
   235  	go serverHandle(t, hs, done, lis)
   236  	return lis
   237  }
   238  
   239  // Is run in a separate goroutine.
   240  func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) {
   241  	serverRawConn, err := lis.Accept()
   242  	if err != nil {
   243  		t.Errorf("Server failed to accept connection: %v", err)
   244  		close(done)
   245  		return
   246  	}
   247  	serverAuthInfo, err := hs(serverRawConn)
   248  	if err != nil {
   249  		t.Errorf("Server failed while handshake. Error: %v", err)
   250  		serverRawConn.Close()
   251  		close(done)
   252  		return
   253  	}
   254  	done <- serverAuthInfo
   255  }
   256  
   257  func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo {
   258  	conn, err := net.Dial("tcp", lisAddr)
   259  	if err != nil {
   260  		t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
   261  	}
   262  	defer conn.Close()
   263  	clientAuthInfo, err := hs(conn, lisAddr)
   264  	if err != nil {
   265  		t.Fatalf("Error on client while handshake. Error: %v", err)
   266  	}
   267  	return clientAuthInfo
   268  }
   269  
   270  // Server handshake implementation in gRPC.
   271  func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) {
   272  	serverTLS, err := NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   273  	if err != nil {
   274  		return nil, err
   275  	}
   276  	_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
   277  	if err != nil {
   278  		return nil, err
   279  	}
   280  	return serverAuthInfo, nil
   281  }
   282  
   283  // Client handshake implementation in gRPC.
   284  func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) {
   285  	clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true})
   286  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   287  	defer cancel()
   288  	_, authInfo, err := clientTLS.ClientHandshake(ctx, lisAddr, conn)
   289  	if err != nil {
   290  		return nil, err
   291  	}
   292  	return authInfo, nil
   293  }
   294  
   295  func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
   296  	cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   297  	if err != nil {
   298  		return nil, err
   299  	}
   300  	serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
   301  	serverConn := tls.Server(conn, serverTLSConfig)
   302  	err = serverConn.Handshake()
   303  	if err != nil {
   304  		return nil, err
   305  	}
   306  	return TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil
   307  }
   308  
   309  func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
   310  	clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
   311  	clientConn := tls.Client(conn, clientTLSConfig)
   312  	if err := clientConn.Handshake(); err != nil {
   313  		return nil, err
   314  	}
   315  	return TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil
   316  }
   317  

View as plain text