1 package pgconn
2
3 import (
4 "context"
5 "crypto/tls"
6 "crypto/x509"
7 "encoding/pem"
8 "errors"
9 "fmt"
10 "io"
11 "math"
12 "net"
13 "net/url"
14 "os"
15 "path/filepath"
16 "strconv"
17 "strings"
18 "time"
19
20 "github.com/jackc/pgpassfile"
21 "github.com/jackc/pgservicefile"
22 "github.com/jackc/pgx/v5/pgproto3"
23 )
24
25 type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
26 type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
27 type GetSSLPasswordFunc func(ctx context.Context) string
28
29
30
// Config holds the settings used to establish a connection to a PostgreSQL
// server. ParseConfig and ParseConfigWithOptions fill it from a connection
// string, PG* environment variables, an optional service file and defaults.
// NOTE(review): the unexported createdByParseConfig flag suggests a Config is
// expected to come from ParseConfig rather than a struct literal — confirm
// against the connect code.
type Config struct {
	// Host is a host name, an IP address, or an absolute path to a Unix socket
	// directory (see NetworkAddress).
	Host string
	Port uint16
	Database string
	User string
	Password string
	// TLSConfig is the TLS configuration for Host; nil means TLS is not used.
	TLSConfig *tls.Config
	// ConnectTimeout is the connect_timeout setting. ParseConfig also applies it
	// as the Timeout of the dialer behind DialFunc.
	ConnectTimeout time.Duration
	DialFunc DialFunc
	LookupFunc LookupFunc
	// BuildFrontend creates the protocol frontend from a reader and a writer;
	// ParseConfig sets it to pgproto3.NewFrontend.
	BuildFrontend BuildFrontendFunc
	// RuntimeParams holds every parsed setting that pgconn does not consume
	// itself (for example application_name).
	// NOTE(review): presumably sent to the server as startup parameters — confirm.
	RuntimeParams map[string]string

	// KerberosSrvName and KerberosSpn come from the krbsrvname and krbspn settings.
	KerberosSrvName string
	KerberosSpn string
	// Fallbacks lists additional host/port/TLS combinations. ParseConfig stores
	// every candidate after the first one here; the first populates Host, Port
	// and TLSConfig.
	// NOTE(review): presumably tried in order when the primary target fails.
	Fallbacks []*FallbackConfig

	// ValidateConnect, when non-nil, checks a newly established connection.
	// ParseConfig sets it according to target_session_attrs.
	// NOTE(review): presumably a non-nil error rejects the connection — confirm.
	ValidateConnect ValidateConnectFunc

	// AfterConnect is an optional hook called with the new connection.
	// NOTE(review): its ordering relative to ValidateConnect is not visible here.
	AfterConnect AfterConnectFunc

	// OnNotice is an optional NoticeHandler; semantics are defined by that type.
	OnNotice NoticeHandler

	// OnNotification is an optional NotificationHandler; semantics are defined
	// by that type.
	OnNotification NotificationHandler

	// OnPgError is called with errors received from the server. ParseConfig
	// installs a handler that returns false for FATAL errors and true otherwise.
	// NOTE(review): the meaning of the return value is defined by PgErrorHandler.
	OnPgError PgErrorHandler

	// createdByParseConfig is set to true by ParseConfigWithOptions.
	createdByParseConfig bool
}
70
71
// ParseConfigOptions holds optional inputs to ParseConfigWithOptions that
// cannot be expressed in a connection string.
type ParseConfigOptions struct {
	// GetSSLPassword, when non-nil, supplies the password for an encrypted
	// sslkey. It is called only if no sslpassword setting is given or that
	// password fails to decrypt the key.
	GetSSLPassword GetSSLPasswordFunc
}
77
78
79
80
81 func (c *Config) Copy() *Config {
82 newConf := new(Config)
83 *newConf = *c
84 if newConf.TLSConfig != nil {
85 newConf.TLSConfig = c.TLSConfig.Clone()
86 }
87 if newConf.RuntimeParams != nil {
88 newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
89 for k, v := range c.RuntimeParams {
90 newConf.RuntimeParams[k] = v
91 }
92 }
93 if newConf.Fallbacks != nil {
94 newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
95 for i, fallback := range c.Fallbacks {
96 newFallback := new(FallbackConfig)
97 *newFallback = *fallback
98 if newFallback.TLSConfig != nil {
99 newFallback.TLSConfig = fallback.TLSConfig.Clone()
100 }
101 newConf.Fallbacks[i] = newFallback
102 }
103 }
104 return newConf
105 }
106
107
108
// FallbackConfig is one alternative connection target: a host, port and TLS
// configuration. ParseConfig creates one per host and TLS attempt.
type FallbackConfig struct {
	// Host is a host name, IP address, or absolute Unix socket directory path.
	Host string
	Port uint16
	// TLSConfig is nil when this attempt should not use TLS.
	TLSConfig *tls.Config
}
114
115
116
117
// isAbsolutePath reports whether path is an absolute filesystem path, which
// NetworkAddress treats as a Unix-domain socket directory. Unix-style paths
// ("/tmp") and Windows drive paths ("C:\sock", "c:\sock", "C:/sock") are both
// recognised, independent of the OS the program is running on.
func isAbsolutePath(path string) bool {
	if strings.HasPrefix(path, "/") {
		return true
	}

	// Windows drive path: drive letter, colon, separator. Drive letters are
	// case-insensitive, and Windows accepts either slash as the separator.
	if len(path) < 3 {
		return false
	}
	drive, colon, sep := path[0], path[1], path[2]
	isLetter := ('A' <= drive && drive <= 'Z') || ('a' <= drive && drive <= 'z')
	return isLetter && colon == ':' && (sep == '\\' || sep == '/')
}
133
134
135
136 func NetworkAddress(host string, port uint16) (network, address string) {
137 if isAbsolutePath(host) {
138 network = "unix"
139 address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
140 } else {
141 network = "tcp"
142 address = net.JoinHostPort(host, strconv.Itoa(int(port)))
143 }
144 return network, address
145 }
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221 func ParseConfig(connString string) (*Config, error) {
222 var parseConfigOptions ParseConfigOptions
223 return ParseConfigWithOptions(connString, parseConfigOptions)
224 }
225
226
227
228
// ParseConfigWithOptions builds a *Config from connString. The string is
// parsed as a URL when it starts with "postgres://" or "postgresql://" and as
// a keyword/value DSN otherwise. An empty connString uses only defaults and
// environment variables.
//
// Settings are merged in this order, with later sources overriding earlier ones:
//  1. defaultSettings()
//  2. PG* environment variables (parseEnvSettings)
//  3. the service-file entry named by "service", if any
//  4. connString itself
//
// options.GetSSLPassword is passed through to TLS setup for encrypted client
// keys. Every error is returned as a *ParseConfigError.
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
	defaultSettings := defaultSettings()
	envSettings := parseEnvSettings()

	connStringSettings := make(map[string]string)
	if connString != "" {
		var err error
		// A URL scheme selects URL parsing; anything else is treated as a DSN.
		if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
			connStringSettings, err = parseURLSettings(connString)
			if err != nil {
				return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
			}
		} else {
			connStringSettings, err = parseDSNSettings(connString)
			if err != nil {
				return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err}
			}
		}
	}

	settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
	// The service name can come from any source, so it is only known after a
	// first merge. The service-file settings are then slotted in between the
	// environment and the connection string, and everything is merged again.
	if service, present := settings["service"]; present {
		serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
		if err != nil {
			return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
		}

		settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
	}

	config := &Config{
		createdByParseConfig: true,
		Database:             settings["database"],
		User:                 settings["user"],
		Password:             settings["password"],
		RuntimeParams:        make(map[string]string),
		BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
			return pgproto3.NewFrontend(r, w)
		},
		OnPgError: func(_ *PgConn, pgErr *PgError) bool {
			// Returns false for FATAL-severity errors and true for everything else.
			// NOTE(review): presumably false means the connection should be closed —
			// confirm against the PgErrorHandler contract.
			if strings.EqualFold(pgErr.Severity, "FATAL") {
				return false
			}
			return true
		},
	}

	// connect_timeout (whole seconds) becomes the dialer timeout; without it the
	// default dialer, which has no timeout, is used.
	if connectTimeoutSetting, present := settings["connect_timeout"]; present {
		connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
		if err != nil {
			return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
		}
		config.ConnectTimeout = connectTimeout
		config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
	} else {
		defaultDialer := makeDefaultDialer()
		config.DialFunc = defaultDialer.DialContext
	}

	config.LookupFunc = makeDefaultResolver().LookupHost

	// Settings consumed by pgconn itself. Every other setting is copied into
	// config.RuntimeParams below.
	notRuntimeParams := map[string]struct{}{
		"host":                 {},
		"port":                 {},
		"database":             {},
		"user":                 {},
		"password":             {},
		"passfile":             {},
		"connect_timeout":      {},
		"sslmode":              {},
		"sslkey":               {},
		"sslcert":              {},
		"sslrootcert":          {},
		"sslpassword":          {},
		"sslsni":               {},
		"krbspn":               {},
		"krbsrvname":           {},
		"target_session_attrs": {},
		"service":              {},
		"servicefile":          {},
	}

	// Kerberos settings are only copied when present.
	if _, present := settings["krbsrvname"]; present {
		config.KerberosSrvName = settings["krbsrvname"]
	}
	if _, present := settings["krbspn"]; present {
		config.KerberosSpn = settings["krbspn"]
	}

	for k, v := range settings {
		if _, present := notRuntimeParams[k]; present {
			continue
		}
		config.RuntimeParams[k] = v
	}

	// Build one FallbackConfig per (host, TLS config) pair. "host" and "port"
	// are parallel comma-separated lists; a host without a matching port entry
	// uses the first port.
	fallbacks := []*FallbackConfig{}

	hosts := strings.Split(settings["host"], ",")
	ports := strings.Split(settings["port"], ",")

	for i, host := range hosts {
		var portStr string
		if i < len(ports) {
			portStr = ports[i]
		} else {
			portStr = ports[0]
		}

		port, err := parsePort(portStr)
		if err != nil {
			return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
		}

		var tlsConfigs []*tls.Config

		// Unix domain sockets never use TLS; TCP hosts get one or two attempts
		// depending on sslmode (see configTLS).
		if network, _ := NetworkAddress(host, port); network == "unix" {
			tlsConfigs = append(tlsConfigs, nil)
		} else {
			var err error
			tlsConfigs, err = configTLS(settings, host, options)
			if err != nil {
				return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
			}
		}

		for _, tlsConfig := range tlsConfigs {
			fallbacks = append(fallbacks, &FallbackConfig{
				Host:      host,
				Port:      port,
				TLSConfig: tlsConfig,
			})
		}
	}

	// strings.Split always returns at least one host and configTLS at least one
	// entry, so fallbacks is non-empty. The first candidate becomes the primary
	// target and the rest are kept as fallbacks.
	config.Host = fallbacks[0].Host
	config.Port = fallbacks[0].Port
	config.TLSConfig = fallbacks[0].TLSConfig
	config.Fallbacks = fallbacks[1:]

	// The passfile is consulted only when no password was supplied, and only for
	// the primary host. Unix socket targets are looked up as "localhost". Errors
	// reading the passfile are ignored.
	passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
	if err == nil {
		if config.Password == "" {
			host := config.Host
			if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
				host = "localhost"
			}

			config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
		}
	}

	// target_session_attrs selects the post-connect validation. Unknown values
	// are rejected.
	switch tsa := settings["target_session_attrs"]; tsa {
	case "read-write":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
	case "read-only":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
	case "primary":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
	case "standby":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
	case "prefer-standby":
		config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
	case "any":
		// No validation. NOTE(review): defaultSettings presumably supplies "any",
		// since an empty value would hit the default case below.
	default:
		return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
	}

	return config, nil
}
404
// mergeSettings combines settingSets into a newly allocated map. Sets are
// applied in argument order, so a key present in a later set overrides the
// same key from an earlier one. Nil sets are skipped and the inputs are not
// modified. The result is never nil.
func mergeSettings(settingSets ...map[string]string) map[string]string {
	merged := make(map[string]string)
	for i := range settingSets {
		for key, val := range settingSets[i] {
			merged[key] = val
		}
	}
	return merged
}
416
// parseEnvSettings reads the libpq-style PG* environment variables and returns
// them keyed by their connection-setting name (for example PGHOST -> "host").
// Unset or empty variables are omitted.
func parseEnvSettings() map[string]string {
	envVars := [...]struct{ env, setting string }{
		{"PGHOST", "host"},
		{"PGPORT", "port"},
		{"PGDATABASE", "database"},
		{"PGUSER", "user"},
		{"PGPASSWORD", "password"},
		{"PGPASSFILE", "passfile"},
		{"PGAPPNAME", "application_name"},
		{"PGCONNECT_TIMEOUT", "connect_timeout"},
		{"PGSSLMODE", "sslmode"},
		{"PGSSLKEY", "sslkey"},
		{"PGSSLCERT", "sslcert"},
		{"PGSSLSNI", "sslsni"},
		{"PGSSLROOTCERT", "sslrootcert"},
		{"PGSSLPASSWORD", "sslpassword"},
		{"PGTARGETSESSIONATTRS", "target_session_attrs"},
		{"PGSERVICE", "service"},
		{"PGSERVICEFILE", "servicefile"},
	}

	settings := make(map[string]string)
	for _, v := range envVars {
		if value := os.Getenv(v.env); value != "" {
			settings[v.setting] = value
		}
	}
	return settings
}
449
450 func parseURLSettings(connString string) (map[string]string, error) {
451 settings := make(map[string]string)
452
453 url, err := url.Parse(connString)
454 if err != nil {
455 return nil, err
456 }
457
458 if url.User != nil {
459 settings["user"] = url.User.Username()
460 if password, present := url.User.Password(); present {
461 settings["password"] = password
462 }
463 }
464
465
466 var hosts []string
467 var ports []string
468 for _, host := range strings.Split(url.Host, ",") {
469 if host == "" {
470 continue
471 }
472 if isIPOnly(host) {
473 hosts = append(hosts, strings.Trim(host, "[]"))
474 continue
475 }
476 h, p, err := net.SplitHostPort(host)
477 if err != nil {
478 return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
479 }
480 if h != "" {
481 hosts = append(hosts, h)
482 }
483 if p != "" {
484 ports = append(ports, p)
485 }
486 }
487 if len(hosts) > 0 {
488 settings["host"] = strings.Join(hosts, ",")
489 }
490 if len(ports) > 0 {
491 settings["port"] = strings.Join(ports, ",")
492 }
493
494 database := strings.TrimLeft(url.Path, "/")
495 if database != "" {
496 settings["database"] = database
497 }
498
499 nameMap := map[string]string{
500 "dbname": "database",
501 }
502
503 for k, v := range url.Query() {
504 if k2, present := nameMap[k]; present {
505 k = k2
506 }
507
508 settings[k] = v[0]
509 }
510
511 return settings, nil
512 }
513
// isIPOnly reports whether a URL host entry can be used without splitting off
// a port. This is true when the entry contains no colon at all, or when it is
// a bare IP address, optionally wrapped in [] as IPv6 literals are in URLs.
func isIPOnly(host string) bool {
	if !strings.Contains(host, ":") {
		return true
	}
	return net.ParseIP(strings.Trim(host, "[]")) != nil
}
517
// asciiSpace marks the bytes that terminate an unquoted DSN value: tab,
// newline, vertical tab, form feed, carriage return and space.
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}

// parseDSNSettings parses a keyword/value connection string such as
// "host=localhost port=5432 dbname=mydb" into a settings map. The "dbname"
// keyword is stored under "database".
//
// A value is either unquoted, ending at the first whitespace, or enclosed in
// single quotes. In both forms a backslash makes the following byte part of
// the value. The escape sequences \\ and \' are then resolved to \ and ' in a
// single pass; any other backslash is kept literally, so paths such as
// C:\certs\root.crt need no escaping.
//
// Whitespace around keys, after '=', between pairs and at the end of the
// string is ignored. An error is returned for:
//   - a pair without '='
//   - an empty key
//   - an unterminated quoted value
//   - an unquoted value that ends in a lone backslash
func parseDSNSettings(s string) (map[string]string, error) {
	const whitespace = " \t\n\r\v\f"

	settings := make(map[string]string)

	nameMap := map[string]string{
		"dbname": "database",
	}

	for {
		// Skip separators. Trailing whitespace simply ends the input.
		s = strings.TrimLeft(s, whitespace)
		if s == "" {
			break
		}

		eqIdx := strings.IndexByte(s, '=')
		if eqIdx < 0 {
			return nil, errors.New("invalid dsn")
		}

		key := strings.Trim(s[:eqIdx], whitespace)
		s = strings.TrimLeft(s[eqIdx+1:], whitespace)

		var val string
		switch {
		case s == "":
			// "key=" at the very end of the input: empty value.
		case s[0] != '\'':
			end := 0
			for ; end < len(s); end++ {
				if asciiSpace[s[end]] == 1 {
					break
				}
				if s[end] == '\\' {
					end++
					if end == len(s) {
						return nil, errors.New("invalid backslash")
					}
				}
			}
			val = unescapeDSNValue(s[:end])
			// The terminating whitespace (if any) is skipped at the top of the loop.
			s = s[end:]
		default:
			s = s[1:]
			end := 0
			for ; end < len(s); end++ {
				if s[end] == '\'' {
					break
				}
				if s[end] == '\\' {
					end++
				}
			}
			// When the input ends with a backslash inside the quotes, end ends
			// up at len(s)+1, so >= is required to avoid slicing out of range.
			if end >= len(s) {
				return nil, errors.New("unterminated quoted string in connection info string")
			}
			val = unescapeDSNValue(s[:end])
			s = s[end+1:]
		}

		if k, ok := nameMap[key]; ok {
			key = k
		}

		if key == "" {
			return nil, errors.New("invalid dsn")
		}

		settings[key] = val
	}

	return settings, nil
}

// unescapeDSNValue resolves the \\ and \' escape sequences of a raw DSN value
// in one left-to-right pass, so that \\' decodes to \' rather than '. Any
// other backslash is kept as-is.
func unescapeDSNValue(s string) string {
	if strings.IndexByte(s, '\\') < 0 {
		return s
	}
	var b strings.Builder
	b.Grow(len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c == '\\' && i+1 < len(s) && (s[i+1] == '\\' || s[i+1] == '\'') {
			i++
			c = s[i]
		}
		b.WriteByte(c)
	}
	return b.String()
}
591
592 func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
593 servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
594 if err != nil {
595 return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
596 }
597
598 service, err := servicefile.GetService(serviceName)
599 if err != nil {
600 return nil, fmt.Errorf("unable to find service: %v", serviceName)
601 }
602
603 nameMap := map[string]string{
604 "dbname": "database",
605 }
606
607 settings := make(map[string]string, len(service.Settings))
608 for k, v := range service.Settings {
609 if k2, present := nameMap[k]; present {
610 k = k2
611 }
612 settings[k] = v
613 }
614
615 return settings, nil
616 }
617
618
619
620
621 func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
622 host := thisHost
623 sslmode := settings["sslmode"]
624 sslrootcert := settings["sslrootcert"]
625 sslcert := settings["sslcert"]
626 sslkey := settings["sslkey"]
627 sslpassword := settings["sslpassword"]
628 sslsni := settings["sslsni"]
629
630
631 if sslmode == "" {
632 sslmode = "prefer"
633 }
634 if sslsni == "" {
635 sslsni = "1"
636 }
637
638 tlsConfig := &tls.Config{}
639
640 switch sslmode {
641 case "disable":
642 return []*tls.Config{nil}, nil
643 case "allow", "prefer":
644 tlsConfig.InsecureSkipVerify = true
645 case "require":
646
647
648
649
650 if sslrootcert != "" {
651 goto nextCase
652 }
653 tlsConfig.InsecureSkipVerify = true
654 break
655 nextCase:
656 fallthrough
657 case "verify-ca":
658
659
660
661
662
663
664
665
666
667 tlsConfig.InsecureSkipVerify = true
668 tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
669 certs := make([]*x509.Certificate, len(certificates))
670 for i, asn1Data := range certificates {
671 cert, err := x509.ParseCertificate(asn1Data)
672 if err != nil {
673 return errors.New("failed to parse certificate from server: " + err.Error())
674 }
675 certs[i] = cert
676 }
677
678
679 opts := x509.VerifyOptions{
680 Roots: tlsConfig.RootCAs,
681 Intermediates: x509.NewCertPool(),
682 }
683
684
685 for _, cert := range certs[1:] {
686 opts.Intermediates.AddCert(cert)
687 }
688 _, err := certs[0].Verify(opts)
689 return err
690 }
691 case "verify-full":
692 tlsConfig.ServerName = host
693 default:
694 return nil, errors.New("sslmode is invalid")
695 }
696
697 if sslrootcert != "" {
698 caCertPool := x509.NewCertPool()
699
700 caPath := sslrootcert
701 caCert, err := os.ReadFile(caPath)
702 if err != nil {
703 return nil, fmt.Errorf("unable to read CA file: %w", err)
704 }
705
706 if !caCertPool.AppendCertsFromPEM(caCert) {
707 return nil, errors.New("unable to add CA to cert pool")
708 }
709
710 tlsConfig.RootCAs = caCertPool
711 tlsConfig.ClientCAs = caCertPool
712 }
713
714 if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
715 return nil, errors.New(`both "sslcert" and "sslkey" are required`)
716 }
717
718 if sslcert != "" && sslkey != "" {
719 buf, err := os.ReadFile(sslkey)
720 if err != nil {
721 return nil, fmt.Errorf("unable to read sslkey: %w", err)
722 }
723 block, _ := pem.Decode(buf)
724 if block == nil {
725 return nil, errors.New("failed to decode sslkey")
726 }
727 var pemKey []byte
728 var decryptedKey []byte
729 var decryptedError error
730
731 if x509.IsEncryptedPEMBlock(block) {
732
733
734 if sslpassword != "" {
735 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
736 }
737
738
739 if sslpassword == "" || decryptedError != nil {
740 if parseConfigOptions.GetSSLPassword != nil {
741 sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
742 }
743 if sslpassword == "" {
744 return nil, fmt.Errorf("unable to find sslpassword")
745 }
746 }
747 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
748
749 if decryptedError != nil {
750 return nil, fmt.Errorf("unable to decrypt key: %w", err)
751 }
752
753 pemBytes := pem.Block{
754 Type: "RSA PRIVATE KEY",
755 Bytes: decryptedKey,
756 }
757 pemKey = pem.EncodeToMemory(&pemBytes)
758 } else {
759 pemKey = pem.EncodeToMemory(block)
760 }
761 certfile, err := os.ReadFile(sslcert)
762 if err != nil {
763 return nil, fmt.Errorf("unable to read cert: %w", err)
764 }
765 cert, err := tls.X509KeyPair(certfile, pemKey)
766 if err != nil {
767 return nil, fmt.Errorf("unable to load cert: %w", err)
768 }
769 tlsConfig.Certificates = []tls.Certificate{cert}
770 }
771
772
773
774
775 if sslsni == "1" && net.ParseIP(host) == nil {
776 tlsConfig.ServerName = host
777 }
778
779 switch sslmode {
780 case "allow":
781 return []*tls.Config{nil, tlsConfig}, nil
782 case "prefer":
783 return []*tls.Config{tlsConfig, nil}, nil
784 case "require", "verify-ca", "verify-full":
785 return []*tls.Config{tlsConfig}, nil
786 default:
787 panic("BUG: bad sslmode should already have been caught")
788 }
789 }
790
// parsePort parses a decimal TCP port. The strconv error is returned for
// non-numeric or over-wide input. Port 0, or anything above math.MaxUint16,
// is rejected as "outside range".
func parsePort(s string) (uint16, error) {
	n, err := strconv.ParseUint(s, 10, 16)
	switch {
	case err != nil:
		return 0, err
	case n == 0 || n > math.MaxUint16:
		return 0, errors.New("outside range")
	}
	return uint16(n), nil
}
801
// makeDefaultDialer returns a new net.Dialer with KeepAlive set to 5 minutes
// and every other field at its zero value, so there is no dial timeout.
func makeDefaultDialer() *net.Dialer {
	d := new(net.Dialer)
	d.KeepAlive = 5 * time.Minute
	return d
}
805
// makeDefaultResolver returns the process-wide net.DefaultResolver; it does
// not create a new resolver.
func makeDefaultResolver() *net.Resolver {
	resolver := net.DefaultResolver
	return resolver
}
809
// parseConnectTimeoutSetting parses a connect_timeout value, given in whole
// seconds, into a time.Duration. Non-numeric input returns the strconv error.
// Negative values are rejected, as are values whose duration would not fit in
// a time.Duration.
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
	timeout, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return 0, err
	}
	if timeout < 0 {
		return 0, errors.New("negative timeout")
	}
	// Multiplying by time.Second would silently overflow int64 (and could wrap
	// negative) for values above this bound.
	if timeout > math.MaxInt64/int64(time.Second) {
		return 0, errors.New("timeout too large")
	}
	return time.Duration(timeout) * time.Second, nil
}
820
821 func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
822 d := makeDefaultDialer()
823 d.Timeout = timeout
824 return d.DialContext
825 }
826
827
828
829 func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
830 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
831 if result.Err != nil {
832 return result.Err
833 }
834
835 if string(result.Rows[0][0]) == "on" {
836 return errors.New("read only connection")
837 }
838
839 return nil
840 }
841
842
843
844 func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
845 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
846 if result.Err != nil {
847 return result.Err
848 }
849
850 if string(result.Rows[0][0]) != "on" {
851 return errors.New("connection is not read only")
852 }
853
854 return nil
855 }
856
857
858
859 func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
860 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
861 if result.Err != nil {
862 return result.Err
863 }
864
865 if string(result.Rows[0][0]) != "t" {
866 return errors.New("server is not in hot standby mode")
867 }
868
869 return nil
870 }
871
872
873
874 func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
875 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
876 if result.Err != nil {
877 return result.Err
878 }
879
880 if string(result.Rows[0][0]) == "t" {
881 return errors.New("server is in standby mode")
882 }
883
884 return nil
885 }
886
887
888
889 func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
890 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
891 if result.Err != nil {
892 return result.Err
893 }
894
895 if string(result.Rows[0][0]) != "t" {
896 return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
897 }
898
899 return nil
900 }
901
View as plain text