...

Source file src/github.com/jackc/pgconn/config.go

Documentation: github.com/jackc/pgconn

     1  package pgconn
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/pem"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"math"
    13  	"net"
    14  	"net/url"
    15  	"os"
    16  	"path/filepath"
    17  	"strconv"
    18  	"strings"
    19  	"time"
    20  
    21  	"github.com/jackc/chunkreader/v2"
    22  	"github.com/jackc/pgpassfile"
    23  	"github.com/jackc/pgproto3/v2"
    24  	"github.com/jackc/pgservicefile"
    25  )
    26  
// AfterConnectFunc is the type of the Config.AfterConnect callback.
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error

// ValidateConnectFunc is the type of the Config.ValidateConnect callback. The
// ValidateConnectTargetSessionAttrs* functions in this file have this signature.
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error

// GetSSLPasswordFunc is the type of the ParseConfigOptions.GetSSLPassword callback. It returns the pass phrase used to
// decrypt an encrypted client key; an empty string is treated as "no password found".
type GetSSLPasswordFunc func(ctx context.Context) string
    30  
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
// manually initialized Config will cause ConnectConfig to panic.
//
// Host, Port and TLSConfig describe the first connection attempt; Fallbacks lists the remaining attempts. When
// ConnectTimeout is set by ParseConfig, DialFunc is a dialer that uses the same timeout.
type Config struct {
	Host           string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
	Port           uint16
	Database       string
	User           string
	Password       string
	TLSConfig      *tls.Config // nil disables TLS
	ConnectTimeout time.Duration
	DialFunc       DialFunc   // e.g. net.Dialer.DialContext
	LookupFunc     LookupFunc // e.g. net.Resolver.LookupHost
	BuildFrontend  BuildFrontendFunc
	RuntimeParams  map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)

	// KerberosSrvName and KerberosSpn are filled from the krbsrvname and krbspn settings. Fallbacks holds the
	// additional host/port/TLS combinations to try after the primary Host, Port and TLSConfig.
	KerberosSrvName string
	KerberosSpn     string
	Fallbacks       []*FallbackConfig

	// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
	// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
	// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
	ValidateConnect ValidateConnectFunc

	// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
	// or prepare statements). If this returns an error the connection attempt fails.
	AfterConnect AfterConnectFunc

	// OnNotice is a callback function called when a notice response is received.
	OnNotice NoticeHandler

	// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
	OnNotification NotificationHandler

	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
    67  
// ParseConfigOptions contains options that control how a config is built, such as GetSSLPassword. These are settings
// that cannot be expressed in a connection string.
type ParseConfigOptions struct {
	// GetSSLPassword gets the password to decrypt an SSL client certificate key. This is analogous to the libpq
	// function PQsetSSLKeyPassHook_OpenSSL. It is consulted when the sslpassword setting is missing or fails to
	// decrypt the key.
	GetSSLPassword GetSSLPasswordFunc
}
    74  
    75  // Copy returns a deep copy of the config that is safe to use and modify.
    76  // The only exception is the TLSConfig field:
    77  // according to the tls.Config docs it must not be modified after creation.
    78  func (c *Config) Copy() *Config {
    79  	newConf := new(Config)
    80  	*newConf = *c
    81  	if newConf.TLSConfig != nil {
    82  		newConf.TLSConfig = c.TLSConfig.Clone()
    83  	}
    84  	if newConf.RuntimeParams != nil {
    85  		newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
    86  		for k, v := range c.RuntimeParams {
    87  			newConf.RuntimeParams[k] = v
    88  		}
    89  	}
    90  	if newConf.Fallbacks != nil {
    91  		newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
    92  		for i, fallback := range c.Fallbacks {
    93  			newFallback := new(FallbackConfig)
    94  			*newFallback = *fallback
    95  			if newFallback.TLSConfig != nil {
    96  				newFallback.TLSConfig = fallback.TLSConfig.Clone()
    97  			}
    98  			newConf.Fallbacks[i] = newFallback
    99  		}
   100  	}
   101  	return newConf
   102  }
   103  
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
// ParseConfigWithOptions creates one FallbackConfig per host and TLS configuration combination.
type FallbackConfig struct {
	Host      string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
	Port      uint16
	TLSConfig *tls.Config // nil disables TLS
}
   111  
   112  // isAbsolutePath checks if the provided value is an absolute path either
   113  // beginning with a forward slash (as on Linux-based systems) or with a capital
   114  // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
   115  func isAbsolutePath(path string) bool {
   116  	isWindowsPath := func(p string) bool {
   117  		if len(p) < 3 {
   118  			return false
   119  		}
   120  		drive := p[0]
   121  		colon := p[1]
   122  		backslash := p[2]
   123  		if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
   124  			return true
   125  		}
   126  		return false
   127  	}
   128  	return strings.HasPrefix(path, "/") || isWindowsPath(path)
   129  }
   130  
   131  // NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
   132  // net.Dial.
   133  func NetworkAddress(host string, port uint16) (network, address string) {
   134  	if isAbsolutePath(host) {
   135  		network = "unix"
   136  		address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
   137  	} else {
   138  		network = "tcp"
   139  		address = net.JoinHostPort(host, strconv.Itoa(int(port)))
   140  	}
   141  	return network, address
   142  }
   143  
   144  // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
   145  // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
   146  // matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
   147  // See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
   148  // empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
   149  //
   150  //   # Example DSN
   151  //   user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
   152  //
   153  //   # Example URL
   154  //   postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
   155  //
   156  // The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
   157  // through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
   158  // interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
   159  // not be modified individually. They should all be modified or all left unchanged.
   160  //
   161  // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
   162  // values that will be tried in order. This can be used as part of a high availability system. See
   163  // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
   164  //
   165  //   # Example URL
   166  //   postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
   167  //
   168  // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
   169  // via database URL or DSN:
   170  //
   171  //   PGHOST
   172  //   PGPORT
   173  //   PGDATABASE
   174  //   PGUSER
   175  //   PGPASSWORD
   176  //   PGPASSFILE
   177  //   PGSERVICE
   178  //   PGSERVICEFILE
   179  //   PGSSLMODE
   180  //   PGSSLCERT
   181  //   PGSSLKEY
   182  //   PGSSLROOTCERT
   183  //   PGSSLPASSWORD
   184  //   PGAPPNAME
   185  //   PGCONNECT_TIMEOUT
   186  //   PGTARGETSESSIONATTRS
   187  //
   188  // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
   189  //
   190  // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
   191  // usually but not always the environment variable name downcased and without the "PG" prefix.
   192  //
   193  // Important Security Notes:
   194  //
   195  // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
   196  // not set.
   197  //
   198  // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
   199  // security each sslmode provides.
   200  //
   201  // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
   202  // the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
   203  // sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
   204  // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
   205  // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
   206  // TLSConfig.
   207  //
   208  // Other known differences with libpq:
   209  //
   210  // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
   211  // does not.
   212  //
   213  // In addition, ParseConfig accepts the following options:
   214  //
   215  //  min_read_buffer_size
   216  //    The minimum size of the internal read buffer. Default 8192.
   217  //  servicefile
   218  //    libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
   219  //    part of the connection string.
   220  func ParseConfig(connString string) (*Config, error) {
   221  	var parseConfigOptions ParseConfigOptions
   222  	return ParseConfigWithOptions(connString, parseConfigOptions)
   223  }
   224  
   225  // ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
   226  // C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
   227  // get the SSL password.
   228  func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
   229  	defaultSettings := defaultSettings()
   230  	envSettings := parseEnvSettings()
   231  
   232  	connStringSettings := make(map[string]string)
   233  	if connString != "" {
   234  		var err error
   235  		// connString may be a database URL or a DSN
   236  		if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
   237  			connStringSettings, err = parseURLSettings(connString)
   238  			if err != nil {
   239  				return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
   240  			}
   241  		} else {
   242  			connStringSettings, err = parseDSNSettings(connString)
   243  			if err != nil {
   244  				return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
   245  			}
   246  		}
   247  	}
   248  
   249  	settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
   250  	if service, present := settings["service"]; present {
   251  		serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
   252  		if err != nil {
   253  			return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
   254  		}
   255  
   256  		settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
   257  	}
   258  
   259  	minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
   260  	if err != nil {
   261  		return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
   262  	}
   263  
   264  	config := &Config{
   265  		createdByParseConfig: true,
   266  		Database:             settings["database"],
   267  		User:                 settings["user"],
   268  		Password:             settings["password"],
   269  		RuntimeParams:        make(map[string]string),
   270  		BuildFrontend:        makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
   271  	}
   272  
   273  	if connectTimeoutSetting, present := settings["connect_timeout"]; present {
   274  		connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
   275  		if err != nil {
   276  			return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
   277  		}
   278  		config.ConnectTimeout = connectTimeout
   279  		config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
   280  	} else {
   281  		defaultDialer := makeDefaultDialer()
   282  		config.DialFunc = defaultDialer.DialContext
   283  	}
   284  
   285  	config.LookupFunc = makeDefaultResolver().LookupHost
   286  
   287  	notRuntimeParams := map[string]struct{}{
   288  		"host":                 {},
   289  		"port":                 {},
   290  		"database":             {},
   291  		"user":                 {},
   292  		"password":             {},
   293  		"passfile":             {},
   294  		"connect_timeout":      {},
   295  		"sslmode":              {},
   296  		"sslkey":               {},
   297  		"sslcert":              {},
   298  		"sslrootcert":          {},
   299  		"sslpassword":          {},
   300  		"sslsni":               {},
   301  		"krbspn":               {},
   302  		"krbsrvname":           {},
   303  		"target_session_attrs": {},
   304  		"min_read_buffer_size": {},
   305  		"service":              {},
   306  		"servicefile":          {},
   307  	}
   308  
   309  	// Adding kerberos configuration
   310  	if _, present := settings["krbsrvname"]; present {
   311  		config.KerberosSrvName = settings["krbsrvname"]
   312  	}
   313  	if _, present := settings["krbspn"]; present {
   314  		config.KerberosSpn = settings["krbspn"]
   315  	}
   316  
   317  	for k, v := range settings {
   318  		if _, present := notRuntimeParams[k]; present {
   319  			continue
   320  		}
   321  		config.RuntimeParams[k] = v
   322  	}
   323  
   324  	fallbacks := []*FallbackConfig{}
   325  
   326  	hosts := strings.Split(settings["host"], ",")
   327  	ports := strings.Split(settings["port"], ",")
   328  
   329  	for i, host := range hosts {
   330  		var portStr string
   331  		if i < len(ports) {
   332  			portStr = ports[i]
   333  		} else {
   334  			portStr = ports[0]
   335  		}
   336  
   337  		port, err := parsePort(portStr)
   338  		if err != nil {
   339  			return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
   340  		}
   341  
   342  		var tlsConfigs []*tls.Config
   343  
   344  		// Ignore TLS settings if Unix domain socket like libpq
   345  		if network, _ := NetworkAddress(host, port); network == "unix" {
   346  			tlsConfigs = append(tlsConfigs, nil)
   347  		} else {
   348  			var err error
   349  			tlsConfigs, err = configTLS(settings, host, options)
   350  			if err != nil {
   351  				return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
   352  			}
   353  		}
   354  
   355  		for _, tlsConfig := range tlsConfigs {
   356  			fallbacks = append(fallbacks, &FallbackConfig{
   357  				Host:      host,
   358  				Port:      port,
   359  				TLSConfig: tlsConfig,
   360  			})
   361  		}
   362  	}
   363  
   364  	config.Host = fallbacks[0].Host
   365  	config.Port = fallbacks[0].Port
   366  	config.TLSConfig = fallbacks[0].TLSConfig
   367  	config.Fallbacks = fallbacks[1:]
   368  
   369  	if config.Password == "" {
   370  		passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
   371  		if err == nil {
   372  			host := config.Host
   373  			if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
   374  				host = "localhost"
   375  			}
   376  
   377  			config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
   378  		}
   379  	}
   380  
   381  	switch tsa := settings["target_session_attrs"]; tsa {
   382  	case "read-write":
   383  		config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
   384  	case "read-only":
   385  		config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
   386  	case "primary":
   387  		config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
   388  	case "standby":
   389  		config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
   390  	case "prefer-standby":
   391  		config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
   392  	case "any":
   393  		// do nothing
   394  	default:
   395  		return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
   396  	}
   397  
   398  	return config, nil
   399  }
   400  
   401  func mergeSettings(settingSets ...map[string]string) map[string]string {
   402  	settings := make(map[string]string)
   403  
   404  	for _, s2 := range settingSets {
   405  		for k, v := range s2 {
   406  			settings[k] = v
   407  		}
   408  	}
   409  
   410  	return settings
   411  }
   412  
   413  func parseEnvSettings() map[string]string {
   414  	settings := make(map[string]string)
   415  
   416  	nameMap := map[string]string{
   417  		"PGHOST":               "host",
   418  		"PGPORT":               "port",
   419  		"PGDATABASE":           "database",
   420  		"PGUSER":               "user",
   421  		"PGPASSWORD":           "password",
   422  		"PGPASSFILE":           "passfile",
   423  		"PGAPPNAME":            "application_name",
   424  		"PGCONNECT_TIMEOUT":    "connect_timeout",
   425  		"PGSSLMODE":            "sslmode",
   426  		"PGSSLKEY":             "sslkey",
   427  		"PGSSLCERT":            "sslcert",
   428  		"PGSSLSNI":             "sslsni",
   429  		"PGSSLROOTCERT":        "sslrootcert",
   430  		"PGSSLPASSWORD":        "sslpassword",
   431  		"PGTARGETSESSIONATTRS": "target_session_attrs",
   432  		"PGSERVICE":            "service",
   433  		"PGSERVICEFILE":        "servicefile",
   434  	}
   435  
   436  	for envname, realname := range nameMap {
   437  		value := os.Getenv(envname)
   438  		if value != "" {
   439  			settings[realname] = value
   440  		}
   441  	}
   442  
   443  	return settings
   444  }
   445  
   446  func parseURLSettings(connString string) (map[string]string, error) {
   447  	settings := make(map[string]string)
   448  
   449  	url, err := url.Parse(connString)
   450  	if err != nil {
   451  		return nil, err
   452  	}
   453  
   454  	if url.User != nil {
   455  		settings["user"] = url.User.Username()
   456  		if password, present := url.User.Password(); present {
   457  			settings["password"] = password
   458  		}
   459  	}
   460  
   461  	// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
   462  	var hosts []string
   463  	var ports []string
   464  	for _, host := range strings.Split(url.Host, ",") {
   465  		if host == "" {
   466  			continue
   467  		}
   468  		if isIPOnly(host) {
   469  			hosts = append(hosts, strings.Trim(host, "[]"))
   470  			continue
   471  		}
   472  		h, p, err := net.SplitHostPort(host)
   473  		if err != nil {
   474  			return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
   475  		}
   476  		if h != "" {
   477  			hosts = append(hosts, h)
   478  		}
   479  		if p != "" {
   480  			ports = append(ports, p)
   481  		}
   482  	}
   483  	if len(hosts) > 0 {
   484  		settings["host"] = strings.Join(hosts, ",")
   485  	}
   486  	if len(ports) > 0 {
   487  		settings["port"] = strings.Join(ports, ",")
   488  	}
   489  
   490  	database := strings.TrimLeft(url.Path, "/")
   491  	if database != "" {
   492  		settings["database"] = database
   493  	}
   494  
   495  	nameMap := map[string]string{
   496  		"dbname": "database",
   497  	}
   498  
   499  	for k, v := range url.Query() {
   500  		if k2, present := nameMap[k]; present {
   501  			k = k2
   502  		}
   503  
   504  		settings[k] = v[0]
   505  	}
   506  
   507  	return settings, nil
   508  }
   509  
   510  func isIPOnly(host string) bool {
   511  	return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
   512  }
   513  
   514  var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
   515  
   516  func parseDSNSettings(s string) (map[string]string, error) {
   517  	settings := make(map[string]string)
   518  
   519  	nameMap := map[string]string{
   520  		"dbname": "database",
   521  	}
   522  
   523  	for len(s) > 0 {
   524  		var key, val string
   525  		eqIdx := strings.IndexRune(s, '=')
   526  		if eqIdx < 0 {
   527  			return nil, errors.New("invalid dsn")
   528  		}
   529  
   530  		key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
   531  		s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
   532  		if len(s) == 0 {
   533  		} else if s[0] != '\'' {
   534  			end := 0
   535  			for ; end < len(s); end++ {
   536  				if asciiSpace[s[end]] == 1 {
   537  					break
   538  				}
   539  				if s[end] == '\\' {
   540  					end++
   541  					if end == len(s) {
   542  						return nil, errors.New("invalid backslash")
   543  					}
   544  				}
   545  			}
   546  			val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
   547  			if end == len(s) {
   548  				s = ""
   549  			} else {
   550  				s = s[end+1:]
   551  			}
   552  		} else { // quoted string
   553  			s = s[1:]
   554  			end := 0
   555  			for ; end < len(s); end++ {
   556  				if s[end] == '\'' {
   557  					break
   558  				}
   559  				if s[end] == '\\' {
   560  					end++
   561  				}
   562  			}
   563  			if end == len(s) {
   564  				return nil, errors.New("unterminated quoted string in connection info string")
   565  			}
   566  			val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
   567  			if end == len(s) {
   568  				s = ""
   569  			} else {
   570  				s = s[end+1:]
   571  			}
   572  		}
   573  
   574  		if k, ok := nameMap[key]; ok {
   575  			key = k
   576  		}
   577  
   578  		if key == "" {
   579  			return nil, errors.New("invalid dsn")
   580  		}
   581  
   582  		settings[key] = val
   583  	}
   584  
   585  	return settings, nil
   586  }
   587  
   588  func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
   589  	servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
   590  	if err != nil {
   591  		return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
   592  	}
   593  
   594  	service, err := servicefile.GetService(serviceName)
   595  	if err != nil {
   596  		return nil, fmt.Errorf("unable to find service: %v", serviceName)
   597  	}
   598  
   599  	nameMap := map[string]string{
   600  		"dbname": "database",
   601  	}
   602  
   603  	settings := make(map[string]string, len(service.Settings))
   604  	for k, v := range service.Settings {
   605  		if k2, present := nameMap[k]; present {
   606  			k = k2
   607  		}
   608  		settings[k] = v
   609  	}
   610  
   611  	return settings, nil
   612  }
   613  
   614  // configTLS uses libpq's TLS parameters to construct  []*tls.Config. It is
   615  // necessary to allow returning multiple TLS configs as sslmode "allow" and
   616  // "prefer" allow fallback.
   617  func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
   618  	host := thisHost
   619  	sslmode := settings["sslmode"]
   620  	sslrootcert := settings["sslrootcert"]
   621  	sslcert := settings["sslcert"]
   622  	sslkey := settings["sslkey"]
   623  	sslpassword := settings["sslpassword"]
   624  	sslsni := settings["sslsni"]
   625  
   626  	// Match libpq default behavior
   627  	if sslmode == "" {
   628  		sslmode = "prefer"
   629  	}
   630  	if sslsni == "" {
   631  		sslsni = "1"
   632  	}
   633  
   634  	tlsConfig := &tls.Config{}
   635  
   636  	switch sslmode {
   637  	case "disable":
   638  		return []*tls.Config{nil}, nil
   639  	case "allow", "prefer":
   640  		tlsConfig.InsecureSkipVerify = true
   641  	case "require":
   642  		// According to PostgreSQL documentation, if a root CA file exists,
   643  		// the behavior of sslmode=require should be the same as that of verify-ca
   644  		//
   645  		// See https://www.postgresql.org/docs/12/libpq-ssl.html
   646  		if sslrootcert != "" {
   647  			goto nextCase
   648  		}
   649  		tlsConfig.InsecureSkipVerify = true
   650  		break
   651  	nextCase:
   652  		fallthrough
   653  	case "verify-ca":
   654  		// Don't perform the default certificate verification because it
   655  		// will verify the hostname. Instead, verify the server's
   656  		// certificate chain ourselves in VerifyPeerCertificate and
   657  		// ignore the server name. This emulates libpq's verify-ca
   658  		// behavior.
   659  		//
   660  		// See https://github.com/golang/go/issues/21971#issuecomment-332693931
   661  		// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
   662  		// for more info.
   663  		tlsConfig.InsecureSkipVerify = true
   664  		tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
   665  			certs := make([]*x509.Certificate, len(certificates))
   666  			for i, asn1Data := range certificates {
   667  				cert, err := x509.ParseCertificate(asn1Data)
   668  				if err != nil {
   669  					return errors.New("failed to parse certificate from server: " + err.Error())
   670  				}
   671  				certs[i] = cert
   672  			}
   673  
   674  			// Leave DNSName empty to skip hostname verification.
   675  			opts := x509.VerifyOptions{
   676  				Roots:         tlsConfig.RootCAs,
   677  				Intermediates: x509.NewCertPool(),
   678  			}
   679  			// Skip the first cert because it's the leaf. All others
   680  			// are intermediates.
   681  			for _, cert := range certs[1:] {
   682  				opts.Intermediates.AddCert(cert)
   683  			}
   684  			_, err := certs[0].Verify(opts)
   685  			return err
   686  		}
   687  	case "verify-full":
   688  		tlsConfig.ServerName = host
   689  	default:
   690  		return nil, errors.New("sslmode is invalid")
   691  	}
   692  
   693  	if sslrootcert != "" {
   694  		caCertPool := x509.NewCertPool()
   695  
   696  		caPath := sslrootcert
   697  		caCert, err := ioutil.ReadFile(caPath)
   698  		if err != nil {
   699  			return nil, fmt.Errorf("unable to read CA file: %w", err)
   700  		}
   701  
   702  		if !caCertPool.AppendCertsFromPEM(caCert) {
   703  			return nil, errors.New("unable to add CA to cert pool")
   704  		}
   705  
   706  		tlsConfig.RootCAs = caCertPool
   707  		tlsConfig.ClientCAs = caCertPool
   708  	}
   709  
   710  	if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
   711  		return nil, errors.New(`both "sslcert" and "sslkey" are required`)
   712  	}
   713  
   714  	if sslcert != "" && sslkey != "" {
   715  		buf, err := ioutil.ReadFile(sslkey)
   716  		if err != nil {
   717  			return nil, fmt.Errorf("unable to read sslkey: %w", err)
   718  		}
   719  		block, _ := pem.Decode(buf)
   720  		var pemKey []byte
   721  		var decryptedKey []byte
   722  		var decryptedError error
   723  		// If PEM is encrypted, attempt to decrypt using pass phrase
   724  		if x509.IsEncryptedPEMBlock(block) {
   725  			// Attempt decryption with pass phrase
   726  			// NOTE: only supports RSA (PKCS#1)
   727  			if sslpassword != "" {
   728  				decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
   729  			}
   730  			//if sslpassword not provided or has decryption error when use it
   731  			//try to find sslpassword with callback function
   732  			if sslpassword == "" || decryptedError != nil {
   733  				if parseConfigOptions.GetSSLPassword != nil {
   734  					sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
   735  				}
   736  				if sslpassword == "" {
   737  					return nil, fmt.Errorf("unable to find sslpassword")
   738  				}
   739  			}
   740  			decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
   741  			// Should we also provide warning for PKCS#1 needed?
   742  			if decryptedError != nil {
   743  				return nil, fmt.Errorf("unable to decrypt key: %w", err)
   744  			}
   745  
   746  			pemBytes := pem.Block{
   747  				Type:  "RSA PRIVATE KEY",
   748  				Bytes: decryptedKey,
   749  			}
   750  			pemKey = pem.EncodeToMemory(&pemBytes)
   751  		} else {
   752  			pemKey = pem.EncodeToMemory(block)
   753  		}
   754  		certfile, err := ioutil.ReadFile(sslcert)
   755  		if err != nil {
   756  			return nil, fmt.Errorf("unable to read cert: %w", err)
   757  		}
   758  		cert, err := tls.X509KeyPair(certfile, pemKey)
   759  		if err != nil {
   760  			return nil, fmt.Errorf("unable to load cert: %w", err)
   761  		}
   762  		tlsConfig.Certificates = []tls.Certificate{cert}
   763  	}
   764  
   765  	// Set Server Name Indication (SNI), if enabled by connection parameters.
   766  	// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
   767  	// or IPv6).
   768  	if sslsni == "1" && net.ParseIP(host) == nil {
   769  		tlsConfig.ServerName = host
   770  	}
   771  
   772  	switch sslmode {
   773  	case "allow":
   774  		return []*tls.Config{nil, tlsConfig}, nil
   775  	case "prefer":
   776  		return []*tls.Config{tlsConfig, nil}, nil
   777  	case "require", "verify-ca", "verify-full":
   778  		return []*tls.Config{tlsConfig}, nil
   779  	default:
   780  		panic("BUG: bad sslmode should already have been caught")
   781  	}
   782  }
   783  
   784  func parsePort(s string) (uint16, error) {
   785  	port, err := strconv.ParseUint(s, 10, 16)
   786  	if err != nil {
   787  		return 0, err
   788  	}
   789  	if port < 1 || port > math.MaxUint16 {
   790  		return 0, errors.New("outside range")
   791  	}
   792  	return uint16(port), nil
   793  }
   794  
   795  func makeDefaultDialer() *net.Dialer {
   796  	return &net.Dialer{KeepAlive: 5 * time.Minute}
   797  }
   798  
   799  func makeDefaultResolver() *net.Resolver {
   800  	return net.DefaultResolver
   801  }
   802  
   803  func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
   804  	return func(r io.Reader, w io.Writer) Frontend {
   805  		cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
   806  		if err != nil {
   807  			panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
   808  		}
   809  		frontend := pgproto3.NewFrontend(cr, w)
   810  
   811  		return frontend
   812  	}
   813  }
   814  
   815  func parseConnectTimeoutSetting(s string) (time.Duration, error) {
   816  	timeout, err := strconv.ParseInt(s, 10, 64)
   817  	if err != nil {
   818  		return 0, err
   819  	}
   820  	if timeout < 0 {
   821  		return 0, errors.New("negative timeout")
   822  	}
   823  	return time.Duration(timeout) * time.Second, nil
   824  }
   825  
   826  func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
   827  	d := makeDefaultDialer()
   828  	d.Timeout = timeout
   829  	return d.DialContext
   830  }
   831  
   832  // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
   833  // target_session_attrs=read-write.
   834  func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
   835  	result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
   836  	if result.Err != nil {
   837  		return result.Err
   838  	}
   839  
   840  	if string(result.Rows[0][0]) == "on" {
   841  		return errors.New("read only connection")
   842  	}
   843  
   844  	return nil
   845  }
   846  
   847  // ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
   848  // target_session_attrs=read-only.
   849  func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
   850  	result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
   851  	if result.Err != nil {
   852  		return result.Err
   853  	}
   854  
   855  	if string(result.Rows[0][0]) != "on" {
   856  		return errors.New("connection is not read only")
   857  	}
   858  
   859  	return nil
   860  }
   861  
   862  // ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
   863  // target_session_attrs=standby.
   864  func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
   865  	result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
   866  	if result.Err != nil {
   867  		return result.Err
   868  	}
   869  
   870  	if string(result.Rows[0][0]) != "t" {
   871  		return errors.New("server is not in hot standby mode")
   872  	}
   873  
   874  	return nil
   875  }
   876  
   877  // ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
   878  // target_session_attrs=primary.
   879  func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
   880  	result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
   881  	if result.Err != nil {
   882  		return result.Err
   883  	}
   884  
   885  	if string(result.Rows[0][0]) == "t" {
   886  		return errors.New("server is in standby mode")
   887  	}
   888  
   889  	return nil
   890  }
   891  
   892  // ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
   893  // target_session_attrs=prefer-standby.
   894  func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
   895  	result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
   896  	if result.Err != nil {
   897  		return result.Err
   898  	}
   899  
   900  	if string(result.Rows[0][0]) != "t" {
   901  		return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
   902  	}
   903  
   904  	return nil
   905  }
   906  

View as plain text