...

Source file src/github.com/jackc/pgx/v5/pgconn/config.go

Documentation: github.com/jackc/pgx/v5/pgconn

     1  package pgconn
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/pem"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"math"
    12  	"net"
    13  	"net/url"
    14  	"os"
    15  	"path/filepath"
    16  	"strconv"
    17  	"strings"
    18  	"time"
    19  
    20  	"github.com/jackc/pgpassfile"
    21  	"github.com/jackc/pgservicefile"
    22  	"github.com/jackc/pgx/v5/pgproto3"
    23  )
    24  
    25  type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
    26  type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
    27  type GetSSLPasswordFunc func(ctx context.Context) string
    28  
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
// manually initialized Config will cause ConnectConfig to panic.
//
// Host, Port, TLSConfig and Fallbacks are interdependent. The first three describe the first connection attempt, and
// Fallbacks lists the further attempts in order. Modify them together or not at all (see [ParseConfig]). Use
// [Config.Copy] to obtain an independent copy that can be modified safely.
type Config struct {
	// Connection target and credentials. ParseConfig fills these from the connection string, the environment, an
	// optional service file and, for Password only, the .pgpass file.
	Host           string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
	Port           uint16
	Database       string
	User           string
	Password       string
	TLSConfig      *tls.Config // nil disables TLS
	ConnectTimeout time.Duration
	DialFunc       DialFunc   // e.g. net.Dialer.DialContext
	LookupFunc     LookupFunc // e.g. net.Resolver.LookupHost
	BuildFrontend  BuildFrontendFunc
	RuntimeParams  map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)

	// KerberosSrvName and KerberosSpn are set by ParseConfig from the krbsrvname and krbspn settings. Fallbacks holds
	// the additional host/port/TLS combinations that are tried, in order, after the primary Host/Port/TLSConfig.
	KerberosSrvName string
	KerberosSpn     string
	Fallbacks       []*FallbackConfig

	// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
	// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
	// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
	ValidateConnect ValidateConnectFunc

	// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. set session variables
	// or prepare statements). If this returns an error the connection attempt fails.
	AfterConnect AfterConnectFunc

	// OnNotice is a callback function called when a notice response is received.
	OnNotice NoticeHandler

	// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
	OnNotification NotificationHandler

	// OnPgError is a callback function called when a Postgres error is received from the server. The default handler will close
	// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
	// that you close on FATAL errors by returning false.
	OnPgError PgErrorHandler

	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
    70  
// ParseConfigOptions contains options that control how a config is built such as GetSSLPassword. These are settings
// that cannot be expressed in a connection string.
type ParseConfigOptions struct {
	// GetSSLPassword gets the password to decrypt an encrypted SSL client key (sslkey). This is analogous to the libpq
	// function PQsetSSLKeyPassHook_OpenSSL. It is only consulted when sslpassword is empty or fails to decrypt the key;
	// returning "" means no password is available and makes config parsing fail.
	GetSSLPassword GetSSLPasswordFunc
}
    77  
    78  // Copy returns a deep copy of the config that is safe to use and modify.
    79  // The only exception is the TLSConfig field:
    80  // according to the tls.Config docs it must not be modified after creation.
    81  func (c *Config) Copy() *Config {
    82  	newConf := new(Config)
    83  	*newConf = *c
    84  	if newConf.TLSConfig != nil {
    85  		newConf.TLSConfig = c.TLSConfig.Clone()
    86  	}
    87  	if newConf.RuntimeParams != nil {
    88  		newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
    89  		for k, v := range c.RuntimeParams {
    90  			newConf.RuntimeParams[k] = v
    91  		}
    92  	}
    93  	if newConf.Fallbacks != nil {
    94  		newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
    95  		for i, fallback := range c.Fallbacks {
    96  			newFallback := new(FallbackConfig)
    97  			*newFallback = *fallback
    98  			if newFallback.TLSConfig != nil {
    99  				newFallback.TLSConfig = fallback.TLSConfig.Clone()
   100  			}
   101  			newConf.Fallbacks[i] = newFallback
   102  		}
   103  	}
   104  	return newConf
   105  }
   106  
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
//
// ParseConfig produces one FallbackConfig per host and TLS configuration to try. The first becomes the primary
// Config.Host/Port/TLSConfig, and the rest are stored in Config.Fallbacks in attempt order.
type FallbackConfig struct {
	Host      string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
	Port      uint16
	TLSConfig *tls.Config // nil disables TLS
}
   114  
   115  // isAbsolutePath checks if the provided value is an absolute path either
   116  // beginning with a forward slash (as on Linux-based systems) or with a capital
   117  // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
   118  func isAbsolutePath(path string) bool {
   119  	isWindowsPath := func(p string) bool {
   120  		if len(p) < 3 {
   121  			return false
   122  		}
   123  		drive := p[0]
   124  		colon := p[1]
   125  		backslash := p[2]
   126  		if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
   127  			return true
   128  		}
   129  		return false
   130  	}
   131  	return strings.HasPrefix(path, "/") || isWindowsPath(path)
   132  }
   133  
   134  // NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
   135  // net.Dial.
   136  func NetworkAddress(host string, port uint16) (network, address string) {
   137  	if isAbsolutePath(host) {
   138  		network = "unix"
   139  		address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
   140  	} else {
   141  		network = "tcp"
   142  		address = net.JoinHostPort(host, strconv.Itoa(int(port)))
   143  	}
   144  	return network, address
   145  }
   146  
   147  // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
   148  // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
   149  // matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
   150  // See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
   151  // empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
   152  //
   153  //	# Example DSN
   154  //	user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
   155  //
   156  //	# Example URL
   157  //	postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
   158  //
   159  // The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
   160  // through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
   161  // interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
   162  // not be modified individually. They should all be modified or all left unchanged.
   163  //
   164  // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
   165  // values that will be tried in order. This can be used as part of a high availability system. See
   166  // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
   167  //
   168  //	# Example URL
   169  //	postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
   170  //
   171  // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
   172  // via database URL or DSN:
   173  //
   174  //	PGHOST
   175  //	PGPORT
   176  //	PGDATABASE
   177  //	PGUSER
   178  //	PGPASSWORD
   179  //	PGPASSFILE
   180  //	PGSERVICE
   181  //	PGSERVICEFILE
   182  //	PGSSLMODE
   183  //	PGSSLCERT
   184  //	PGSSLKEY
   185  //	PGSSLROOTCERT
   186  //	PGSSLPASSWORD
   187  //	PGAPPNAME
   188  //	PGCONNECT_TIMEOUT
   189  //	PGTARGETSESSIONATTRS
   190  //
   191  // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
   192  //
   193  // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
   194  // usually but not always the environment variable name downcased and without the "PG" prefix.
   195  //
   196  // Important Security Notes:
   197  //
   198  // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
   199  // not set.
   200  //
   201  // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
   202  // security each sslmode provides.
   203  //
   204  // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
   205  // the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
   206  // sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
   207  // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
   208  // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
   209  // TLSConfig.
   210  //
   211  // Other known differences with libpq:
   212  //
   213  // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
   214  // does not.
   215  //
   216  // In addition, ParseConfig accepts the following options:
   217  //
   218  //   - servicefile.
   219  //     libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
   220  //     part of the connection string.
   221  func ParseConfig(connString string) (*Config, error) {
   222  	var parseConfigOptions ParseConfigOptions
   223  	return ParseConfigWithOptions(connString, parseConfigOptions)
   224  }
   225  
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
// get the SSL password.
//
// Settings are merged with increasing precedence:
//
//  1. built-in defaults;
//  2. PG* environment variables;
//  3. the named service's entry in the service file (only when a service is given);
//  4. connString itself.
//
// Every error returned is a *ParseConfigError.
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
	defaultSettings := defaultSettings()
	envSettings := parseEnvSettings()

	connStringSettings := make(map[string]string)
	if connString != "" {
		var err error
		// connString may be a database URL or a DSN
		if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
			connStringSettings, err = parseURLSettings(connString)
			if err != nil {
				return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
			}
		} else {
			connStringSettings, err = parseDSNSettings(connString)
			if err != nil {
				return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err}
			}
		}
	}

	settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
	// service and servicefile may come from any of the sources above, so they are looked up in this first merge. The
	// settings are then re-merged with the service's entry ranked above the environment but below connString.
	if service, present := settings["service"]; present {
		serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
		if err != nil {
			return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
		}

		settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
	}

	config := &Config{
		createdByParseConfig: true,
		Database:             settings["database"],
		User:                 settings["user"],
		Password:             settings["password"],
		RuntimeParams:        make(map[string]string),
		BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
			return pgproto3.NewFrontend(r, w)
		},
		OnPgError: func(_ *PgConn, pgErr *PgError) bool {
			// we want to automatically close any fatal errors
			if strings.EqualFold(pgErr.Severity, "FATAL") {
				return false
			}
			return true
		},
	}

	// connect_timeout sets both Config.ConnectTimeout and the dial timeout of the default dialer.
	if connectTimeoutSetting, present := settings["connect_timeout"]; present {
		connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
		if err != nil {
			return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
		}
		config.ConnectTimeout = connectTimeout
		config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
	} else {
		defaultDialer := makeDefaultDialer()
		config.DialFunc = defaultDialer.DialContext
	}

	config.LookupFunc = makeDefaultResolver().LookupHost

	// Settings consumed by ParseConfig itself. Every other setting is forwarded to the server as a run-time parameter.
	notRuntimeParams := map[string]struct{}{
		"host":                 {},
		"port":                 {},
		"database":             {},
		"user":                 {},
		"password":             {},
		"passfile":             {},
		"connect_timeout":      {},
		"sslmode":              {},
		"sslkey":               {},
		"sslcert":              {},
		"sslrootcert":          {},
		"sslpassword":          {},
		"sslsni":               {},
		"krbspn":               {},
		"krbsrvname":           {},
		"target_session_attrs": {},
		"service":              {},
		"servicefile":          {},
	}

	// Adding kerberos configuration
	if _, present := settings["krbsrvname"]; present {
		config.KerberosSrvName = settings["krbsrvname"]
	}
	if _, present := settings["krbspn"]; present {
		config.KerberosSpn = settings["krbspn"]
	}

	for k, v := range settings {
		if _, present := notRuntimeParams[k]; present {
			continue
		}
		config.RuntimeParams[k] = v
	}

	// Build the ordered list of connection attempts: one entry per (host, TLS config) pair. The first entry becomes the
	// primary Host/Port/TLSConfig and the remainder become config.Fallbacks.
	fallbacks := []*FallbackConfig{}

	hosts := strings.Split(settings["host"], ",")
	ports := strings.Split(settings["port"], ",")

	for i, host := range hosts {
		// Hosts without a matching port entry reuse the first port.
		var portStr string
		if i < len(ports) {
			portStr = ports[i]
		} else {
			portStr = ports[0]
		}

		port, err := parsePort(portStr)
		if err != nil {
			return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
		}

		var tlsConfigs []*tls.Config

		// Ignore TLS settings if Unix domain socket like libpq
		if network, _ := NetworkAddress(host, port); network == "unix" {
			tlsConfigs = append(tlsConfigs, nil)
		} else {
			var err error
			tlsConfigs, err = configTLS(settings, host, options)
			if err != nil {
				return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
			}
		}

		for _, tlsConfig := range tlsConfigs {
			fallbacks = append(fallbacks, &FallbackConfig{
				Host:      host,
				Port:      port,
				TLSConfig: tlsConfig,
			})
		}
	}

	// strings.Split always returns at least one element, and each host above appends at least one entry, so
	// fallbacks[0] exists here.
	config.Host = fallbacks[0].Host
	config.Port = fallbacks[0].Port
	config.TLSConfig = fallbacks[0].TLSConfig
	config.Fallbacks = fallbacks[1:]

	// Errors reading the passfile are deliberately ignored. The passfile is only consulted when no password was
	// supplied, and only for the primary host and port. A Unix socket host is looked up as "localhost".
	passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
	if err == nil {
		if config.Password == "" {
			host := config.Host
			if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
				host = "localhost"
			}

			config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
		}
	}

	// Map target_session_attrs onto the matching ValidateConnect hook. Unknown values are rejected.
	switch tsa := settings["target_session_attrs"]; tsa {
	case "read-write":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
	case "read-only":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
	case "primary":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
	case "standby":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
	case "prefer-standby":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
	case "any":
		// do nothing
	default:
		return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
	}

	return config, nil
}
   404  
   405  func mergeSettings(settingSets ...map[string]string) map[string]string {
   406  	settings := make(map[string]string)
   407  
   408  	for _, s2 := range settingSets {
   409  		for k, v := range s2 {
   410  			settings[k] = v
   411  		}
   412  	}
   413  
   414  	return settings
   415  }
   416  
   417  func parseEnvSettings() map[string]string {
   418  	settings := make(map[string]string)
   419  
   420  	nameMap := map[string]string{
   421  		"PGHOST":               "host",
   422  		"PGPORT":               "port",
   423  		"PGDATABASE":           "database",
   424  		"PGUSER":               "user",
   425  		"PGPASSWORD":           "password",
   426  		"PGPASSFILE":           "passfile",
   427  		"PGAPPNAME":            "application_name",
   428  		"PGCONNECT_TIMEOUT":    "connect_timeout",
   429  		"PGSSLMODE":            "sslmode",
   430  		"PGSSLKEY":             "sslkey",
   431  		"PGSSLCERT":            "sslcert",
   432  		"PGSSLSNI":             "sslsni",
   433  		"PGSSLROOTCERT":        "sslrootcert",
   434  		"PGSSLPASSWORD":        "sslpassword",
   435  		"PGTARGETSESSIONATTRS": "target_session_attrs",
   436  		"PGSERVICE":            "service",
   437  		"PGSERVICEFILE":        "servicefile",
   438  	}
   439  
   440  	for envname, realname := range nameMap {
   441  		value := os.Getenv(envname)
   442  		if value != "" {
   443  			settings[realname] = value
   444  		}
   445  	}
   446  
   447  	return settings
   448  }
   449  
   450  func parseURLSettings(connString string) (map[string]string, error) {
   451  	settings := make(map[string]string)
   452  
   453  	url, err := url.Parse(connString)
   454  	if err != nil {
   455  		return nil, err
   456  	}
   457  
   458  	if url.User != nil {
   459  		settings["user"] = url.User.Username()
   460  		if password, present := url.User.Password(); present {
   461  			settings["password"] = password
   462  		}
   463  	}
   464  
   465  	// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
   466  	var hosts []string
   467  	var ports []string
   468  	for _, host := range strings.Split(url.Host, ",") {
   469  		if host == "" {
   470  			continue
   471  		}
   472  		if isIPOnly(host) {
   473  			hosts = append(hosts, strings.Trim(host, "[]"))
   474  			continue
   475  		}
   476  		h, p, err := net.SplitHostPort(host)
   477  		if err != nil {
   478  			return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
   479  		}
   480  		if h != "" {
   481  			hosts = append(hosts, h)
   482  		}
   483  		if p != "" {
   484  			ports = append(ports, p)
   485  		}
   486  	}
   487  	if len(hosts) > 0 {
   488  		settings["host"] = strings.Join(hosts, ",")
   489  	}
   490  	if len(ports) > 0 {
   491  		settings["port"] = strings.Join(ports, ",")
   492  	}
   493  
   494  	database := strings.TrimLeft(url.Path, "/")
   495  	if database != "" {
   496  		settings["database"] = database
   497  	}
   498  
   499  	nameMap := map[string]string{
   500  		"dbname": "database",
   501  	}
   502  
   503  	for k, v := range url.Query() {
   504  		if k2, present := nameMap[k]; present {
   505  			k = k2
   506  		}
   507  
   508  		settings[k] = v[0]
   509  	}
   510  
   511  	return settings, nil
   512  }
   513  
   514  func isIPOnly(host string) bool {
   515  	return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
   516  }
   517  
   518  var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
   519  
   520  func parseDSNSettings(s string) (map[string]string, error) {
   521  	settings := make(map[string]string)
   522  
   523  	nameMap := map[string]string{
   524  		"dbname": "database",
   525  	}
   526  
   527  	for len(s) > 0 {
   528  		var key, val string
   529  		eqIdx := strings.IndexRune(s, '=')
   530  		if eqIdx < 0 {
   531  			return nil, errors.New("invalid dsn")
   532  		}
   533  
   534  		key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
   535  		s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
   536  		if len(s) == 0 {
   537  		} else if s[0] != '\'' {
   538  			end := 0
   539  			for ; end < len(s); end++ {
   540  				if asciiSpace[s[end]] == 1 {
   541  					break
   542  				}
   543  				if s[end] == '\\' {
   544  					end++
   545  					if end == len(s) {
   546  						return nil, errors.New("invalid backslash")
   547  					}
   548  				}
   549  			}
   550  			val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
   551  			if end == len(s) {
   552  				s = ""
   553  			} else {
   554  				s = s[end+1:]
   555  			}
   556  		} else { // quoted string
   557  			s = s[1:]
   558  			end := 0
   559  			for ; end < len(s); end++ {
   560  				if s[end] == '\'' {
   561  					break
   562  				}
   563  				if s[end] == '\\' {
   564  					end++
   565  				}
   566  			}
   567  			if end == len(s) {
   568  				return nil, errors.New("unterminated quoted string in connection info string")
   569  			}
   570  			val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
   571  			if end == len(s) {
   572  				s = ""
   573  			} else {
   574  				s = s[end+1:]
   575  			}
   576  		}
   577  
   578  		if k, ok := nameMap[key]; ok {
   579  			key = k
   580  		}
   581  
   582  		if key == "" {
   583  			return nil, errors.New("invalid dsn")
   584  		}
   585  
   586  		settings[key] = val
   587  	}
   588  
   589  	return settings, nil
   590  }
   591  
   592  func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
   593  	servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
   594  	if err != nil {
   595  		return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
   596  	}
   597  
   598  	service, err := servicefile.GetService(serviceName)
   599  	if err != nil {
   600  		return nil, fmt.Errorf("unable to find service: %v", serviceName)
   601  	}
   602  
   603  	nameMap := map[string]string{
   604  		"dbname": "database",
   605  	}
   606  
   607  	settings := make(map[string]string, len(service.Settings))
   608  	for k, v := range service.Settings {
   609  		if k2, present := nameMap[k]; present {
   610  			k = k2
   611  		}
   612  		settings[k] = v
   613  	}
   614  
   615  	return settings, nil
   616  }
   617  
   618  // configTLS uses libpq's TLS parameters to construct  []*tls.Config. It is
   619  // necessary to allow returning multiple TLS configs as sslmode "allow" and
   620  // "prefer" allow fallback.
   621  func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
   622  	host := thisHost
   623  	sslmode := settings["sslmode"]
   624  	sslrootcert := settings["sslrootcert"]
   625  	sslcert := settings["sslcert"]
   626  	sslkey := settings["sslkey"]
   627  	sslpassword := settings["sslpassword"]
   628  	sslsni := settings["sslsni"]
   629  
   630  	// Match libpq default behavior
   631  	if sslmode == "" {
   632  		sslmode = "prefer"
   633  	}
   634  	if sslsni == "" {
   635  		sslsni = "1"
   636  	}
   637  
   638  	tlsConfig := &tls.Config{}
   639  
   640  	switch sslmode {
   641  	case "disable":
   642  		return []*tls.Config{nil}, nil
   643  	case "allow", "prefer":
   644  		tlsConfig.InsecureSkipVerify = true
   645  	case "require":
   646  		// According to PostgreSQL documentation, if a root CA file exists,
   647  		// the behavior of sslmode=require should be the same as that of verify-ca
   648  		//
   649  		// See https://www.postgresql.org/docs/12/libpq-ssl.html
   650  		if sslrootcert != "" {
   651  			goto nextCase
   652  		}
   653  		tlsConfig.InsecureSkipVerify = true
   654  		break
   655  	nextCase:
   656  		fallthrough
   657  	case "verify-ca":
   658  		// Don't perform the default certificate verification because it
   659  		// will verify the hostname. Instead, verify the server's
   660  		// certificate chain ourselves in VerifyPeerCertificate and
   661  		// ignore the server name. This emulates libpq's verify-ca
   662  		// behavior.
   663  		//
   664  		// See https://github.com/golang/go/issues/21971#issuecomment-332693931
   665  		// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
   666  		// for more info.
   667  		tlsConfig.InsecureSkipVerify = true
   668  		tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
   669  			certs := make([]*x509.Certificate, len(certificates))
   670  			for i, asn1Data := range certificates {
   671  				cert, err := x509.ParseCertificate(asn1Data)
   672  				if err != nil {
   673  					return errors.New("failed to parse certificate from server: " + err.Error())
   674  				}
   675  				certs[i] = cert
   676  			}
   677  
   678  			// Leave DNSName empty to skip hostname verification.
   679  			opts := x509.VerifyOptions{
   680  				Roots:         tlsConfig.RootCAs,
   681  				Intermediates: x509.NewCertPool(),
   682  			}
   683  			// Skip the first cert because it's the leaf. All others
   684  			// are intermediates.
   685  			for _, cert := range certs[1:] {
   686  				opts.Intermediates.AddCert(cert)
   687  			}
   688  			_, err := certs[0].Verify(opts)
   689  			return err
   690  		}
   691  	case "verify-full":
   692  		tlsConfig.ServerName = host
   693  	default:
   694  		return nil, errors.New("sslmode is invalid")
   695  	}
   696  
   697  	if sslrootcert != "" {
   698  		caCertPool := x509.NewCertPool()
   699  
   700  		caPath := sslrootcert
   701  		caCert, err := os.ReadFile(caPath)
   702  		if err != nil {
   703  			return nil, fmt.Errorf("unable to read CA file: %w", err)
   704  		}
   705  
   706  		if !caCertPool.AppendCertsFromPEM(caCert) {
   707  			return nil, errors.New("unable to add CA to cert pool")
   708  		}
   709  
   710  		tlsConfig.RootCAs = caCertPool
   711  		tlsConfig.ClientCAs = caCertPool
   712  	}
   713  
   714  	if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
   715  		return nil, errors.New(`both "sslcert" and "sslkey" are required`)
   716  	}
   717  
   718  	if sslcert != "" && sslkey != "" {
   719  		buf, err := os.ReadFile(sslkey)
   720  		if err != nil {
   721  			return nil, fmt.Errorf("unable to read sslkey: %w", err)
   722  		}
   723  		block, _ := pem.Decode(buf)
   724  		if block == nil {
   725  			return nil, errors.New("failed to decode sslkey")
   726  		}
   727  		var pemKey []byte
   728  		var decryptedKey []byte
   729  		var decryptedError error
   730  		// If PEM is encrypted, attempt to decrypt using pass phrase
   731  		if x509.IsEncryptedPEMBlock(block) {
   732  			// Attempt decryption with pass phrase
   733  			// NOTE: only supports RSA (PKCS#1)
   734  			if sslpassword != "" {
   735  				decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
   736  			}
   737  			//if sslpassword not provided or has decryption error when use it
   738  			//try to find sslpassword with callback function
   739  			if sslpassword == "" || decryptedError != nil {
   740  				if parseConfigOptions.GetSSLPassword != nil {
   741  					sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
   742  				}
   743  				if sslpassword == "" {
   744  					return nil, fmt.Errorf("unable to find sslpassword")
   745  				}
   746  			}
   747  			decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
   748  			// Should we also provide warning for PKCS#1 needed?
   749  			if decryptedError != nil {
   750  				return nil, fmt.Errorf("unable to decrypt key: %w", err)
   751  			}
   752  
   753  			pemBytes := pem.Block{
   754  				Type:  "RSA PRIVATE KEY",
   755  				Bytes: decryptedKey,
   756  			}
   757  			pemKey = pem.EncodeToMemory(&pemBytes)
   758  		} else {
   759  			pemKey = pem.EncodeToMemory(block)
   760  		}
   761  		certfile, err := os.ReadFile(sslcert)
   762  		if err != nil {
   763  			return nil, fmt.Errorf("unable to read cert: %w", err)
   764  		}
   765  		cert, err := tls.X509KeyPair(certfile, pemKey)
   766  		if err != nil {
   767  			return nil, fmt.Errorf("unable to load cert: %w", err)
   768  		}
   769  		tlsConfig.Certificates = []tls.Certificate{cert}
   770  	}
   771  
   772  	// Set Server Name Indication (SNI), if enabled by connection parameters.
   773  	// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
   774  	// or IPv6).
   775  	if sslsni == "1" && net.ParseIP(host) == nil {
   776  		tlsConfig.ServerName = host
   777  	}
   778  
   779  	switch sslmode {
   780  	case "allow":
   781  		return []*tls.Config{nil, tlsConfig}, nil
   782  	case "prefer":
   783  		return []*tls.Config{tlsConfig, nil}, nil
   784  	case "require", "verify-ca", "verify-full":
   785  		return []*tls.Config{tlsConfig}, nil
   786  	default:
   787  		panic("BUG: bad sslmode should already have been caught")
   788  	}
   789  }
   790  
   791  func parsePort(s string) (uint16, error) {
   792  	port, err := strconv.ParseUint(s, 10, 16)
   793  	if err != nil {
   794  		return 0, err
   795  	}
   796  	if port < 1 || port > math.MaxUint16 {
   797  		return 0, errors.New("outside range")
   798  	}
   799  	return uint16(port), nil
   800  }
   801  
   802  func makeDefaultDialer() *net.Dialer {
   803  	return &net.Dialer{KeepAlive: 5 * time.Minute}
   804  }
   805  
   806  func makeDefaultResolver() *net.Resolver {
   807  	return net.DefaultResolver
   808  }
   809  
   810  func parseConnectTimeoutSetting(s string) (time.Duration, error) {
   811  	timeout, err := strconv.ParseInt(s, 10, 64)
   812  	if err != nil {
   813  		return 0, err
   814  	}
   815  	if timeout < 0 {
   816  		return 0, errors.New("negative timeout")
   817  	}
   818  	return time.Duration(timeout) * time.Second, nil
   819  }
   820  
   821  func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
   822  	d := makeDefaultDialer()
   823  	d.Timeout = timeout
   824  	return d.DialContext
   825  }
   826  
   827  // ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible
   828  // target_session_attrs=read-write.
   829  func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
   830  	result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
   831  	if result.Err != nil {
   832  		return result.Err
   833  	}
   834  
   835  	if string(result.Rows[0][0]) == "on" {
   836  		return errors.New("read only connection")
   837  	}
   838  
   839  	return nil
   840  }
   841  
   842  // ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible
   843  // target_session_attrs=read-only.
   844  func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
   845  	result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
   846  	if result.Err != nil {
   847  		return result.Err
   848  	}
   849  
   850  	if string(result.Rows[0][0]) != "on" {
   851  		return errors.New("connection is not read only")
   852  	}
   853  
   854  	return nil
   855  }
   856  
   857  // ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible
   858  // target_session_attrs=standby.
   859  func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
   860  	result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
   861  	if result.Err != nil {
   862  		return result.Err
   863  	}
   864  
   865  	if string(result.Rows[0][0]) != "t" {
   866  		return errors.New("server is not in hot standby mode")
   867  	}
   868  
   869  	return nil
   870  }
   871  
   872  // ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible
   873  // target_session_attrs=primary.
   874  func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
   875  	result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
   876  	if result.Err != nil {
   877  		return result.Err
   878  	}
   879  
   880  	if string(result.Rows[0][0]) == "t" {
   881  		return errors.New("server is in standby mode")
   882  	}
   883  
   884  	return nil
   885  }
   886  
   887  // ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible
   888  // target_session_attrs=prefer-standby.
   889  func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
   890  	result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
   891  	if result.Err != nil {
   892  		return result.Err
   893  	}
   894  
   895  	if string(result.Rows[0][0]) != "t" {
   896  		return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
   897  	}
   898  
   899  	return nil
   900  }
   901  

View as plain text