...

Source file src/github.com/lib/pq/ssl_test.go

Documentation: github.com/lib/pq

     1  package pq
     2  
     3  // This file contains SSL tests
     4  
import (
	"bytes"
	_ "crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"
)
    20  
    21  func maybeSkipSSLTests(t *testing.T) {
    22  	// Require some special variables for testing certificates
    23  	if os.Getenv("PQSSLCERTTEST_PATH") == "" {
    24  		t.Skip("PQSSLCERTTEST_PATH not set, skipping SSL tests")
    25  	}
    26  
    27  	value := os.Getenv("PQGOSSLTESTS")
    28  	if value == "" || value == "0" {
    29  		t.Skip("PQGOSSLTESTS not enabled, skipping SSL tests")
    30  	} else if value != "1" {
    31  		t.Fatalf("unexpected value %q for PQGOSSLTESTS", value)
    32  	}
    33  }
    34  
    35  func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) {
    36  	db, err := openTestConnConninfo(conninfo)
    37  	if err != nil {
    38  		// should never fail
    39  		t.Fatal(err)
    40  	}
    41  	// Do something with the connection to see whether it's working or not.
    42  	tx, err := db.Begin()
    43  	if err == nil {
    44  		return db, tx.Rollback()
    45  	}
    46  	_ = db.Close()
    47  	return nil, err
    48  }
    49  
    50  func checkSSLSetup(t *testing.T, conninfo string) {
    51  	_, err := openSSLConn(t, conninfo)
    52  	if pge, ok := err.(*Error); ok {
    53  		if pge.Code.Name() != "invalid_authorization_specification" {
    54  			t.Fatalf("unexpected error code '%s'", pge.Code.Name())
    55  		}
    56  	} else {
    57  		t.Fatalf("expected %T, got %v", (*Error)(nil), err)
    58  	}
    59  }
    60  
    61  // Connect over SSL and run a simple query to test the basics
    62  func TestSSLConnection(t *testing.T) {
    63  	maybeSkipSSLTests(t)
    64  	// Environment sanity check: should fail without SSL
    65  	checkSSLSetup(t, "sslmode=disable user=pqgossltest")
    66  
    67  	db, err := openSSLConn(t, "sslmode=require user=pqgossltest")
    68  	if err != nil {
    69  		t.Fatal(err)
    70  	}
    71  	rows, err := db.Query("SELECT 1")
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  	rows.Close()
    76  }
    77  
    78  // Test sslmode=verify-full
    79  func TestSSLVerifyFull(t *testing.T) {
    80  	maybeSkipSSLTests(t)
    81  	// Environment sanity check: should fail without SSL
    82  	checkSSLSetup(t, "sslmode=disable user=pqgossltest")
    83  
    84  	// Not OK according to the system CA
    85  	_, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest")
    86  	if err == nil {
    87  		t.Fatal("expected error")
    88  	}
    89  	_, ok := err.(x509.UnknownAuthorityError)
    90  	if !ok {
    91  		_, ok := err.(x509.HostnameError)
    92  		if !ok {
    93  			t.Fatalf("expected x509.UnknownAuthorityError or x509.HostnameError, got %#+v", err)
    94  		}
    95  	}
    96  
    97  	rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
    98  	rootCert := "sslrootcert=" + rootCertPath + " "
    99  	// No match on Common Name
   100  	_, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest")
   101  	if err == nil {
   102  		t.Fatal("expected error")
   103  	}
   104  	_, ok = err.(x509.HostnameError)
   105  	if !ok {
   106  		t.Fatalf("expected x509.HostnameError, got %#+v", err)
   107  	}
   108  	// OK
   109  	_, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest")
   110  	if err != nil {
   111  		t.Fatal(err)
   112  	}
   113  }
   114  
   115  // Test sslmode=require sslrootcert=rootCertPath
   116  func TestSSLRequireWithRootCert(t *testing.T) {
   117  	maybeSkipSSLTests(t)
   118  	// Environment sanity check: should fail without SSL
   119  	checkSSLSetup(t, "sslmode=disable user=pqgossltest")
   120  
   121  	bogusRootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "bogus_root.crt")
   122  	bogusRootCert := "sslrootcert=" + bogusRootCertPath + " "
   123  
   124  	// Not OK according to the bogus CA
   125  	_, err := openSSLConn(t, bogusRootCert+"host=postgres sslmode=require user=pqgossltest")
   126  	if err == nil {
   127  		t.Fatal("expected error")
   128  	}
   129  	_, ok := err.(x509.UnknownAuthorityError)
   130  	if !ok {
   131  		t.Fatalf("expected x509.UnknownAuthorityError, got %s, %#+v", err, err)
   132  	}
   133  
   134  	nonExistentCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "non_existent.crt")
   135  	nonExistentCert := "sslrootcert=" + nonExistentCertPath + " "
   136  
   137  	// No match on Common Name, but that's OK because we're not validating anything.
   138  	_, err = openSSLConn(t, nonExistentCert+"host=127.0.0.1 sslmode=require user=pqgossltest")
   139  	if err != nil {
   140  		t.Fatal(err)
   141  	}
   142  
   143  	rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
   144  	rootCert := "sslrootcert=" + rootCertPath + " "
   145  
   146  	// No match on Common Name, but that's OK because we're not validating the CN.
   147  	_, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=require user=pqgossltest")
   148  	if err != nil {
   149  		t.Fatal(err)
   150  	}
   151  	// Everything OK
   152  	_, err = openSSLConn(t, rootCert+"host=postgres sslmode=require user=pqgossltest")
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  }
   157  
   158  // Test sslmode=verify-ca
   159  func TestSSLVerifyCA(t *testing.T) {
   160  	maybeSkipSSLTests(t)
   161  	// Environment sanity check: should fail without SSL
   162  	checkSSLSetup(t, "sslmode=disable user=pqgossltest")
   163  
   164  	// Not OK according to the system CA
   165  	{
   166  		_, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest")
   167  		if _, ok := err.(x509.UnknownAuthorityError); !ok {
   168  			t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err)
   169  		}
   170  	}
   171  
   172  	// Still not OK according to the system CA; empty sslrootcert is treated as unspecified.
   173  	{
   174  		_, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest sslrootcert=''")
   175  		if _, ok := err.(x509.UnknownAuthorityError); !ok {
   176  			t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err)
   177  		}
   178  	}
   179  
   180  	rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
   181  	rootCert := "sslrootcert=" + rootCertPath + " "
   182  	// No match on Common Name, but that's OK
   183  	if _, err := openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest"); err != nil {
   184  		t.Fatal(err)
   185  	}
   186  	// Everything OK
   187  	if _, err := openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest"); err != nil {
   188  		t.Fatal(err)
   189  	}
   190  }
   191  
   192  // Authenticate over SSL using client certificates
   193  func TestSSLClientCertificates(t *testing.T) {
   194  	maybeSkipSSLTests(t)
   195  	// Environment sanity check: should fail without SSL
   196  	checkSSLSetup(t, "sslmode=disable user=pqgossltest")
   197  
   198  	const baseinfo = "sslmode=require user=pqgosslcert"
   199  
   200  	// Certificate not specified, should fail
   201  	{
   202  		_, err := openSSLConn(t, baseinfo)
   203  		if pge, ok := err.(*Error); ok {
   204  			if pge.Code.Name() != "invalid_authorization_specification" {
   205  				t.Fatalf("unexpected error code '%s'", pge.Code.Name())
   206  			}
   207  		} else {
   208  			t.Fatalf("expected %T, got %v", (*Error)(nil), err)
   209  		}
   210  	}
   211  
   212  	// Empty certificate specified, should fail
   213  	{
   214  		_, err := openSSLConn(t, baseinfo+" sslcert=''")
   215  		if pge, ok := err.(*Error); ok {
   216  			if pge.Code.Name() != "invalid_authorization_specification" {
   217  				t.Fatalf("unexpected error code '%s'", pge.Code.Name())
   218  			}
   219  		} else {
   220  			t.Fatalf("expected %T, got %v", (*Error)(nil), err)
   221  		}
   222  	}
   223  
   224  	// Non-existent certificate specified, should fail
   225  	{
   226  		_, err := openSSLConn(t, baseinfo+" sslcert=/tmp/filedoesnotexist")
   227  		if pge, ok := err.(*Error); ok {
   228  			if pge.Code.Name() != "invalid_authorization_specification" {
   229  				t.Fatalf("unexpected error code '%s'", pge.Code.Name())
   230  			}
   231  		} else {
   232  			t.Fatalf("expected %T, got %v", (*Error)(nil), err)
   233  		}
   234  	}
   235  
   236  	certpath, ok := os.LookupEnv("PQSSLCERTTEST_PATH")
   237  	if !ok {
   238  		t.Fatalf("PQSSLCERTTEST_PATH not present in environment")
   239  	}
   240  
   241  	sslcert := filepath.Join(certpath, "postgresql.crt")
   242  
   243  	// Cert present, key not specified, should fail
   244  	{
   245  		_, err := openSSLConn(t, baseinfo+" sslcert="+sslcert)
   246  		if _, ok := err.(*os.PathError); !ok {
   247  			t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err)
   248  		}
   249  	}
   250  
   251  	// Cert present, empty key specified, should fail
   252  	{
   253  		_, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=''")
   254  		if _, ok := err.(*os.PathError); !ok {
   255  			t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err)
   256  		}
   257  	}
   258  
   259  	// Cert present, non-existent key, should fail
   260  	{
   261  		_, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=/tmp/filedoesnotexist")
   262  		if _, ok := err.(*os.PathError); !ok {
   263  			t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err)
   264  		}
   265  	}
   266  
   267  	// Key has wrong permissions (passing the cert as the key), should fail
   268  	if _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslcert); err != ErrSSLKeyHasWorldPermissions {
   269  		t.Fatalf("expected %s, got %#+v", ErrSSLKeyHasWorldPermissions, err)
   270  	}
   271  
   272  	sslkey := filepath.Join(certpath, "postgresql.key")
   273  
   274  	// Should work
   275  	if db, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslkey); err != nil {
   276  		t.Fatal(err)
   277  	} else {
   278  		rows, err := db.Query("SELECT 1")
   279  		if err != nil {
   280  			t.Fatal(err)
   281  		}
   282  		if err := rows.Close(); err != nil {
   283  			t.Fatal(err)
   284  		}
   285  		if err := db.Close(); err != nil {
   286  			t.Fatal(err)
   287  		}
   288  	}
   289  }
   290  
   291  // Check that clint sends SNI data when `sslsni` is not disabled
   292  func TestSNISupport(t *testing.T) {
   293  	t.Parallel()
   294  	tests := []struct {
   295  		name         string
   296  		conn_param   string
   297  		hostname     string
   298  		expected_sni string
   299  	}{
   300  		{
   301  			name:         "SNI is set by default",
   302  			conn_param:   "",
   303  			hostname:     "localhost",
   304  			expected_sni: "localhost",
   305  		},
   306  		{
   307  			name:         "SNI is passed when asked for",
   308  			conn_param:   "sslsni=1",
   309  			hostname:     "localhost",
   310  			expected_sni: "localhost",
   311  		},
   312  		{
   313  			name:         "SNI is not passed when disabled",
   314  			conn_param:   "sslsni=0",
   315  			hostname:     "localhost",
   316  			expected_sni: "",
   317  		},
   318  		{
   319  			name:         "SNI is not set for IPv4",
   320  			conn_param:   "",
   321  			hostname:     "127.0.0.1",
   322  			expected_sni: "",
   323  		},
   324  	}
   325  	for _, tt := range tests {
   326  		tt := tt
   327  		t.Run(tt.name, func(t *testing.T) {
   328  			t.Parallel()
   329  
   330  			// Start mock postgres server on OS-provided port
   331  			listener, err := net.Listen("tcp", "127.0.0.1:")
   332  			if err != nil {
   333  				t.Fatal(err)
   334  			}
   335  			serverErrChan := make(chan error, 1)
   336  			serverSNINameChan := make(chan string, 1)
   337  			go mockPostgresSSL(listener, serverErrChan, serverSNINameChan)
   338  
   339  			defer listener.Close()
   340  			defer close(serverErrChan)
   341  			defer close(serverSNINameChan)
   342  
   343  			// Try to establish a connection with the mock server. Connection will error out after TLS
   344  			// clientHello, but it is enough to catch SNI data on the server side
   345  			port := strings.Split(listener.Addr().String(), ":")[1]
   346  			connStr := fmt.Sprintf("sslmode=require host=%s port=%s %s", tt.hostname, port, tt.conn_param)
   347  
   348  			// We are okay to skip this error as we are polling serverErrChan and we'll get an error
   349  			// or timeout from the server side in case of problems here.
   350  			db, _ := sql.Open("postgres", connStr)
   351  			_, _ = db.Exec("SELECT 1")
   352  
   353  			// Check SNI data
   354  			select {
   355  			case sniHost := <-serverSNINameChan:
   356  				if sniHost != tt.expected_sni {
   357  					t.Fatalf("Expected SNI to be 'localhost', got '%+v' instead", sniHost)
   358  				}
   359  			case err = <-serverErrChan:
   360  				t.Fatalf("mock server failed with error: %+v", err)
   361  			case <-time.After(time.Second):
   362  				t.Fatal("exceeded connection timeout without erroring out")
   363  			}
   364  		})
   365  	}
   366  }
   367  
   368  // Make a postgres mock server to test TLS SNI
   369  //
   370  // Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
   371  // While reading clientHello catch passed SNI data and report it to nameChan.
   372  func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) {
   373  	var sniHost string
   374  
   375  	conn, err := listener.Accept()
   376  	if err != nil {
   377  		errChan <- err
   378  		return
   379  	}
   380  	defer conn.Close()
   381  
   382  	err = conn.SetDeadline(time.Now().Add(time.Second))
   383  	if err != nil {
   384  		errChan <- err
   385  		return
   386  	}
   387  
   388  	// Receive StartupMessage with SSL Request
   389  	startupMessage := make([]byte, 8)
   390  	if _, err := io.ReadFull(conn, startupMessage); err != nil {
   391  		errChan <- err
   392  		return
   393  	}
   394  	// StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
   395  	if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
   396  		errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
   397  		return
   398  	}
   399  
   400  	// Respond with SSLOk
   401  	_, err = conn.Write([]byte("S"))
   402  	if err != nil {
   403  		errChan <- err
   404  		return
   405  	}
   406  
   407  	// Set up TLS context to catch clientHello. It will always error out during handshake
   408  	// as no certificate is set.
   409  	srv := tls.Server(conn, &tls.Config{
   410  		GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
   411  			sniHost = argHello.ServerName
   412  			return nil, nil
   413  		},
   414  	})
   415  	defer srv.Close()
   416  
   417  	// Do the TLS handshake ignoring errors
   418  	_ = srv.Handshake()
   419  
   420  	nameChan <- sniHost
   421  }
   422  

View as plain text