1 package pgconn
2
3 import (
4 "context"
5 "crypto/tls"
6 "crypto/x509"
7 "encoding/pem"
8 "errors"
9 "fmt"
10 "io"
11 "io/ioutil"
12 "math"
13 "net"
14 "net/url"
15 "os"
16 "path/filepath"
17 "strconv"
18 "strings"
19 "time"
20
21 "github.com/jackc/chunkreader/v2"
22 "github.com/jackc/pgpassfile"
23 "github.com/jackc/pgproto3/v2"
24 "github.com/jackc/pgservicefile"
25 )
26
// AfterConnectFunc is the type of Config.AfterConnect: a hook that receives a
// newly established connection. NOTE(review): how a returned error is handled
// is decided by the connect code, which is not visible here — confirm there.
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error

// ValidateConnectFunc is the type of Config.ValidateConnect. The
// ValidateConnectTargetSessionAttrs* functions in this file are implementations
// of it; they return an error when the server does not match the requested
// target_session_attrs.
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error

// GetSSLPasswordFunc supplies the password for an encrypted sslkey when the
// sslpassword setting is empty or fails to decrypt the key. Returning the
// empty string means "no password available", which makes config parsing fail.
type GetSSLPasswordFunc func(ctx context.Context) string
30
31
32
// Config holds the settings used to establish a connection to a PostgreSQL
// server. ParseConfig / ParseConfigWithOptions populate it from a connection
// string, PG* environment variables and, optionally, a service file. Copy
// returns a copy with its own TLSConfig, RuntimeParams and Fallbacks.
type Config struct {
	// Host is a host name or IP address, or an absolute path to the directory
	// that contains a Unix domain socket (see NetworkAddress).
	Host string
	// Port is the TCP port, or the number embedded in the Unix socket file
	// name (".s.PGSQL.<port>").
	Port     uint16
	Database string
	User     string
	// Password; when no password is configured, ParseConfig tries to look one
	// up in the passfile.
	Password string
	// TLSConfig is the TLS configuration for Host/Port. ParseConfig leaves it
	// nil for sslmode=disable and for Unix sockets.
	TLSConfig *tls.Config
	// ConnectTimeout is parsed from connect_timeout (whole seconds).
	// ParseConfig also builds DialFunc with this same timeout.
	ConnectTimeout time.Duration
	// DialFunc opens the network connection. ParseConfig sets it to the
	// DialContext method of a net.Dialer with a 5 minute keep-alive.
	DialFunc DialFunc
	// LookupFunc resolves host names. ParseConfig sets it to
	// net.DefaultResolver.LookupHost.
	LookupFunc LookupFunc
	// BuildFrontend wraps the connection's reader/writer in a protocol
	// frontend. ParseConfig uses a chunkreader sized by min_read_buffer_size.
	BuildFrontend BuildFrontendFunc
	// RuntimeParams holds every setting that pgconn does not consume itself
	// (for example application_name). NOTE(review): presumably sent to the
	// server at startup — confirm in the connect code.
	RuntimeParams map[string]string

	// KerberosSrvName and KerberosSpn come from the krbsrvname and krbspn
	// settings.
	KerberosSrvName string
	KerberosSpn     string
	// Fallbacks lists additional host/port/TLS combinations built by
	// ParseConfig: one per extra host, plus the second TLS attempt for
	// sslmode=allow or prefer. NOTE(review): presumably tried in order when
	// Host/Port/TLSConfig fails — confirm in the connect code.
	Fallbacks []*FallbackConfig

	// ValidateConnect, when non-nil, checks an established connection.
	// ParseConfig sets it from target_session_attrs. NOTE(review): presumably a
	// non-nil error rejects that connection — confirm in the connect code.
	ValidateConnect ValidateConnectFunc

	// AfterConnect is a hook that receives the connection once it is
	// established. NOTE(review): exact timing and error handling live in the
	// connect code, not visible here.
	AfterConnect AfterConnectFunc

	// OnNotice is the NoticeHandler for this connection. NOTE(review):
	// presumably invoked for server notices — confirm.
	OnNotice NoticeHandler

	// OnNotification is the NotificationHandler for this connection.
	// NOTE(review): presumably invoked for LISTEN/NOTIFY notifications — confirm.
	OnNotification NotificationHandler

	// createdByParseConfig is set to true by ParseConfigWithOptions.
	// NOTE(review): presumably checked elsewhere to reject hand-built Configs.
	createdByParseConfig bool
}
67
68
// ParseConfigOptions carries optional behavior for ParseConfigWithOptions.
type ParseConfigOptions struct {
	// GetSSLPassword, if non-nil, is called (with context.Background()) to
	// obtain the password for an encrypted sslkey when the sslpassword setting
	// is empty or fails to decrypt the key.
	GetSSLPassword GetSSLPasswordFunc
}
74
75
76
77
78 func (c *Config) Copy() *Config {
79 newConf := new(Config)
80 *newConf = *c
81 if newConf.TLSConfig != nil {
82 newConf.TLSConfig = c.TLSConfig.Clone()
83 }
84 if newConf.RuntimeParams != nil {
85 newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
86 for k, v := range c.RuntimeParams {
87 newConf.RuntimeParams[k] = v
88 }
89 }
90 if newConf.Fallbacks != nil {
91 newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
92 for i, fallback := range c.Fallbacks {
93 newFallback := new(FallbackConfig)
94 *newFallback = *fallback
95 if newFallback.TLSConfig != nil {
96 newFallback.TLSConfig = fallback.TLSConfig.Clone()
97 }
98 newConf.Fallbacks[i] = newFallback
99 }
100 }
101 return newConf
102 }
103
104
105
// FallbackConfig is one alternative host/port/TLS combination. ParseConfig
// builds one per (host, TLS configuration) pair; the first becomes
// Config.Host/Port/TLSConfig and the rest are stored in Config.Fallbacks.
type FallbackConfig struct {
	Host string
	Port uint16
	// TLSConfig is nil when this attempt should not use TLS.
	TLSConfig *tls.Config
}
111
112
113
114
// isAbsolutePath reports whether path is an absolute filesystem path. It
// accepts either a Unix-style path starting with "/" or a Windows-style path
// with a drive letter (either case), a colon and a separator, such as
// `C:\pgsocket`, `c:\pgsocket` or `C:/pgsocket`. NetworkAddress uses it to
// decide whether a host names a Unix socket directory.
func isAbsolutePath(path string) bool {
	isWindowsPath := func(p string) bool {
		if len(p) < 3 {
			return false
		}
		drive := p[0]
		// Drive letters are case-insensitive on Windows.
		isDriveLetter := (drive >= 'A' && drive <= 'Z') || (drive >= 'a' && drive <= 'z')
		return isDriveLetter && p[1] == ':' && (p[2] == '\\' || p[2] == '/')
	}
	return strings.HasPrefix(path, "/") || isWindowsPath(path)
}

// NetworkAddress converts a host and port into network and address values
// suitable for dialing. An absolute-path host (see isAbsolutePath) is treated
// as the directory of a Unix domain socket named ".s.PGSQL.<port>"; any other
// host is a TCP host name or IP address joined with the port via
// net.JoinHostPort.
func NetworkAddress(host string, port uint16) (network, address string) {
	if isAbsolutePath(host) {
		network = "unix"
		address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
	} else {
		network = "tcp"
		address = net.JoinHostPort(host, strconv.Itoa(int(port)))
	}
	return network, address
}
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220 func ParseConfig(connString string) (*Config, error) {
221 var parseConfigOptions ParseConfigOptions
222 return ParseConfigWithOptions(connString, parseConfigOptions)
223 }
224
225
226
227
228 func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
229 defaultSettings := defaultSettings()
230 envSettings := parseEnvSettings()
231
232 connStringSettings := make(map[string]string)
233 if connString != "" {
234 var err error
235
236 if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
237 connStringSettings, err = parseURLSettings(connString)
238 if err != nil {
239 return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
240 }
241 } else {
242 connStringSettings, err = parseDSNSettings(connString)
243 if err != nil {
244 return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
245 }
246 }
247 }
248
249 settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
250 if service, present := settings["service"]; present {
251 serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
252 if err != nil {
253 return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
254 }
255
256 settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
257 }
258
259 minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
260 if err != nil {
261 return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
262 }
263
264 config := &Config{
265 createdByParseConfig: true,
266 Database: settings["database"],
267 User: settings["user"],
268 Password: settings["password"],
269 RuntimeParams: make(map[string]string),
270 BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
271 }
272
273 if connectTimeoutSetting, present := settings["connect_timeout"]; present {
274 connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
275 if err != nil {
276 return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
277 }
278 config.ConnectTimeout = connectTimeout
279 config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
280 } else {
281 defaultDialer := makeDefaultDialer()
282 config.DialFunc = defaultDialer.DialContext
283 }
284
285 config.LookupFunc = makeDefaultResolver().LookupHost
286
287 notRuntimeParams := map[string]struct{}{
288 "host": {},
289 "port": {},
290 "database": {},
291 "user": {},
292 "password": {},
293 "passfile": {},
294 "connect_timeout": {},
295 "sslmode": {},
296 "sslkey": {},
297 "sslcert": {},
298 "sslrootcert": {},
299 "sslpassword": {},
300 "sslsni": {},
301 "krbspn": {},
302 "krbsrvname": {},
303 "target_session_attrs": {},
304 "min_read_buffer_size": {},
305 "service": {},
306 "servicefile": {},
307 }
308
309
310 if _, present := settings["krbsrvname"]; present {
311 config.KerberosSrvName = settings["krbsrvname"]
312 }
313 if _, present := settings["krbspn"]; present {
314 config.KerberosSpn = settings["krbspn"]
315 }
316
317 for k, v := range settings {
318 if _, present := notRuntimeParams[k]; present {
319 continue
320 }
321 config.RuntimeParams[k] = v
322 }
323
324 fallbacks := []*FallbackConfig{}
325
326 hosts := strings.Split(settings["host"], ",")
327 ports := strings.Split(settings["port"], ",")
328
329 for i, host := range hosts {
330 var portStr string
331 if i < len(ports) {
332 portStr = ports[i]
333 } else {
334 portStr = ports[0]
335 }
336
337 port, err := parsePort(portStr)
338 if err != nil {
339 return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
340 }
341
342 var tlsConfigs []*tls.Config
343
344
345 if network, _ := NetworkAddress(host, port); network == "unix" {
346 tlsConfigs = append(tlsConfigs, nil)
347 } else {
348 var err error
349 tlsConfigs, err = configTLS(settings, host, options)
350 if err != nil {
351 return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
352 }
353 }
354
355 for _, tlsConfig := range tlsConfigs {
356 fallbacks = append(fallbacks, &FallbackConfig{
357 Host: host,
358 Port: port,
359 TLSConfig: tlsConfig,
360 })
361 }
362 }
363
364 config.Host = fallbacks[0].Host
365 config.Port = fallbacks[0].Port
366 config.TLSConfig = fallbacks[0].TLSConfig
367 config.Fallbacks = fallbacks[1:]
368
369 if config.Password == "" {
370 passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
371 if err == nil {
372 host := config.Host
373 if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
374 host = "localhost"
375 }
376
377 config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
378 }
379 }
380
381 switch tsa := settings["target_session_attrs"]; tsa {
382 case "read-write":
383 config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
384 case "read-only":
385 config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
386 case "primary":
387 config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
388 case "standby":
389 config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
390 case "prefer-standby":
391 config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
392 case "any":
393
394 default:
395 return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
396 }
397
398 return config, nil
399 }
400
// mergeSettings flattens settingSets into one newly allocated map. When a key
// occurs in several sets, the value from the last set wins. The input maps
// are never modified, and nil maps are allowed.
func mergeSettings(settingSets ...map[string]string) map[string]string {
	merged := map[string]string{}
	for i := range settingSets {
		for key, value := range settingSets[i] {
			merged[key] = value
		}
	}
	return merged
}
412
// parseEnvSettings reads connection settings from the PG* environment
// variables and returns them keyed by setting name (for example PGAPPNAME is
// stored as "application_name"). Variables that are unset or empty are
// skipped.
func parseEnvSettings() map[string]string {
	envVars := [...]struct {
		name    string
		setting string
	}{
		{"PGHOST", "host"},
		{"PGPORT", "port"},
		{"PGDATABASE", "database"},
		{"PGUSER", "user"},
		{"PGPASSWORD", "password"},
		{"PGPASSFILE", "passfile"},
		{"PGAPPNAME", "application_name"},
		{"PGCONNECT_TIMEOUT", "connect_timeout"},
		{"PGSSLMODE", "sslmode"},
		{"PGSSLKEY", "sslkey"},
		{"PGSSLCERT", "sslcert"},
		{"PGSSLSNI", "sslsni"},
		{"PGSSLROOTCERT", "sslrootcert"},
		{"PGSSLPASSWORD", "sslpassword"},
		{"PGTARGETSESSIONATTRS", "target_session_attrs"},
		{"PGSERVICE", "service"},
		{"PGSERVICEFILE", "servicefile"},
	}

	settings := make(map[string]string)
	for _, v := range envVars {
		value := os.Getenv(v.name)
		if value == "" {
			continue
		}
		settings[v.setting] = value
	}
	return settings
}
445
// parseURLSettings extracts settings from a postgres:// or postgresql:// URL.
// It reads the user and password from the userinfo, and comma-separated
// host[:port] entries from the authority. The database comes from the path
// without leading slashes. Every query parameter becomes a setting holding
// its first value; the query key dbname is renamed to database.
func parseURLSettings(connString string) (map[string]string, error) {
	parsedURL, err := url.Parse(connString)
	if err != nil {
		return nil, err
	}

	settings := make(map[string]string)

	if userInfo := parsedURL.User; userInfo != nil {
		settings["user"] = userInfo.Username()
		if password, ok := userInfo.Password(); ok {
			settings["password"] = password
		}
	}

	// The authority may hold several comma-separated host[:port] entries.
	var hosts, ports []string
	for _, entry := range strings.Split(parsedURL.Host, ",") {
		switch {
		case entry == "":
			// Empty entries (e.g. "h1,,h2") are ignored.
		case isIPOnly(entry):
			hosts = append(hosts, strings.Trim(entry, "[]"))
		default:
			h, p, err := net.SplitHostPort(entry)
			if err != nil {
				return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", entry, err)
			}
			if h != "" {
				hosts = append(hosts, h)
			}
			if p != "" {
				ports = append(ports, p)
			}
		}
	}
	if len(hosts) != 0 {
		settings["host"] = strings.Join(hosts, ",")
	}
	if len(ports) != 0 {
		settings["port"] = strings.Join(ports, ",")
	}

	if database := strings.TrimLeft(parsedURL.Path, "/"); database != "" {
		settings["database"] = database
	}

	queryKeyAliases := map[string]string{
		"dbname": "database",
	}
	for key, values := range parsedURL.Query() {
		if alias, ok := queryKeyAliases[key]; ok {
			key = alias
		}
		settings[key] = values[0]
	}

	return settings, nil
}

// isIPOnly reports whether host carries no port: it is an IP address
// (optionally wrapped in [] brackets) or contains no colon at all.
func isIPOnly(host string) bool {
	if net.ParseIP(strings.Trim(host, "[]")) != nil {
		return true
	}
	return !strings.Contains(host, ":")
}
513
// asciiSpace marks the bytes that terminate an unquoted DSN value.
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}

// parseDSNSettings parses a keyword/value connection string such as
// "host=localhost port=5432 dbname=mydb" into a settings map, renaming dbname
// to database.
//
// A value is either unquoted, ending at the next ASCII whitespace byte, or
// wrapped in single quotes. In both forms a backslash protects the following
// byte while tokenizing. When the value is stored, `\\` becomes `\` and `\'`
// becomes `'`; any other backslash is kept verbatim, so unquoted Windows paths
// such as C:\certs\root.crt survive intact. A "key=" at the end of the input
// yields an empty value.
//
// It returns an error for:
//   - a missing '='
//   - an empty key
//   - an unquoted value ending in a lone backslash
//   - an unterminated quoted value
func parseDSNSettings(s string) (map[string]string, error) {
	settings := make(map[string]string)

	nameMap := map[string]string{
		"dbname": "database",
	}

	// unescape resolves `\\` and `\'` in a single left-to-right pass. A single
	// pass is required: applying the two replacements one after another would
	// turn `\\'` (an escaped backslash followed by a quote) into a bare quote.
	unescape := func(v string) string {
		if !strings.Contains(v, `\`) {
			return v
		}
		var b strings.Builder
		b.Grow(len(v))
		for i := 0; i < len(v); i++ {
			if v[i] == '\\' && i+1 < len(v) && (v[i+1] == '\\' || v[i+1] == '\'') {
				i++ // drop the escaping backslash, keep the escaped byte
			}
			b.WriteByte(v[i])
		}
		return b.String()
	}

	for len(s) > 0 {
		var key, val string
		eqIdx := strings.IndexRune(s, '=')
		if eqIdx < 0 {
			return nil, errors.New("invalid dsn")
		}

		key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
		s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")

		switch {
		case len(s) == 0:
			// "key=" at the end of the input: the value is empty.
		case s[0] != '\'':
			// Unquoted value: runs until the next ASCII whitespace byte.
			end := 0
			for ; end < len(s); end++ {
				if asciiSpace[s[end]] == 1 {
					break
				}
				if s[end] == '\\' {
					end++
					if end == len(s) {
						return nil, errors.New("invalid backslash")
					}
				}
			}
			val = unescape(s[:end])
			if end == len(s) {
				s = ""
			} else {
				s = s[end+1:]
			}
		default:
			// Quoted value: runs until the next unescaped single quote.
			s = s[1:]
			end := 0
			for ; end < len(s); end++ {
				if s[end] == '\'' {
					break
				}
				if s[end] == '\\' {
					end++
				}
			}
			// A trailing lone backslash pushes end to len(s)+1, so test with
			// >= to report an error instead of slicing out of range.
			if end >= len(s) {
				return nil, errors.New("unterminated quoted string in connection info string")
			}
			val = unescape(s[:end])
			s = s[end+1:]
		}

		if k, ok := nameMap[key]; ok {
			key = k
		}

		if key == "" {
			return nil, errors.New("invalid dsn")
		}

		settings[key] = val
	}

	return settings, nil
}
587
588 func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
589 servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
590 if err != nil {
591 return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
592 }
593
594 service, err := servicefile.GetService(serviceName)
595 if err != nil {
596 return nil, fmt.Errorf("unable to find service: %v", serviceName)
597 }
598
599 nameMap := map[string]string{
600 "dbname": "database",
601 }
602
603 settings := make(map[string]string, len(service.Settings))
604 for k, v := range service.Settings {
605 if k2, present := nameMap[k]; present {
606 k = k2
607 }
608 settings[k] = v
609 }
610
611 return settings, nil
612 }
613
614
615
616
617 func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
618 host := thisHost
619 sslmode := settings["sslmode"]
620 sslrootcert := settings["sslrootcert"]
621 sslcert := settings["sslcert"]
622 sslkey := settings["sslkey"]
623 sslpassword := settings["sslpassword"]
624 sslsni := settings["sslsni"]
625
626
627 if sslmode == "" {
628 sslmode = "prefer"
629 }
630 if sslsni == "" {
631 sslsni = "1"
632 }
633
634 tlsConfig := &tls.Config{}
635
636 switch sslmode {
637 case "disable":
638 return []*tls.Config{nil}, nil
639 case "allow", "prefer":
640 tlsConfig.InsecureSkipVerify = true
641 case "require":
642
643
644
645
646 if sslrootcert != "" {
647 goto nextCase
648 }
649 tlsConfig.InsecureSkipVerify = true
650 break
651 nextCase:
652 fallthrough
653 case "verify-ca":
654
655
656
657
658
659
660
661
662
663 tlsConfig.InsecureSkipVerify = true
664 tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
665 certs := make([]*x509.Certificate, len(certificates))
666 for i, asn1Data := range certificates {
667 cert, err := x509.ParseCertificate(asn1Data)
668 if err != nil {
669 return errors.New("failed to parse certificate from server: " + err.Error())
670 }
671 certs[i] = cert
672 }
673
674
675 opts := x509.VerifyOptions{
676 Roots: tlsConfig.RootCAs,
677 Intermediates: x509.NewCertPool(),
678 }
679
680
681 for _, cert := range certs[1:] {
682 opts.Intermediates.AddCert(cert)
683 }
684 _, err := certs[0].Verify(opts)
685 return err
686 }
687 case "verify-full":
688 tlsConfig.ServerName = host
689 default:
690 return nil, errors.New("sslmode is invalid")
691 }
692
693 if sslrootcert != "" {
694 caCertPool := x509.NewCertPool()
695
696 caPath := sslrootcert
697 caCert, err := ioutil.ReadFile(caPath)
698 if err != nil {
699 return nil, fmt.Errorf("unable to read CA file: %w", err)
700 }
701
702 if !caCertPool.AppendCertsFromPEM(caCert) {
703 return nil, errors.New("unable to add CA to cert pool")
704 }
705
706 tlsConfig.RootCAs = caCertPool
707 tlsConfig.ClientCAs = caCertPool
708 }
709
710 if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
711 return nil, errors.New(`both "sslcert" and "sslkey" are required`)
712 }
713
714 if sslcert != "" && sslkey != "" {
715 buf, err := ioutil.ReadFile(sslkey)
716 if err != nil {
717 return nil, fmt.Errorf("unable to read sslkey: %w", err)
718 }
719 block, _ := pem.Decode(buf)
720 var pemKey []byte
721 var decryptedKey []byte
722 var decryptedError error
723
724 if x509.IsEncryptedPEMBlock(block) {
725
726
727 if sslpassword != "" {
728 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
729 }
730
731
732 if sslpassword == "" || decryptedError != nil {
733 if parseConfigOptions.GetSSLPassword != nil {
734 sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
735 }
736 if sslpassword == "" {
737 return nil, fmt.Errorf("unable to find sslpassword")
738 }
739 }
740 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
741
742 if decryptedError != nil {
743 return nil, fmt.Errorf("unable to decrypt key: %w", err)
744 }
745
746 pemBytes := pem.Block{
747 Type: "RSA PRIVATE KEY",
748 Bytes: decryptedKey,
749 }
750 pemKey = pem.EncodeToMemory(&pemBytes)
751 } else {
752 pemKey = pem.EncodeToMemory(block)
753 }
754 certfile, err := ioutil.ReadFile(sslcert)
755 if err != nil {
756 return nil, fmt.Errorf("unable to read cert: %w", err)
757 }
758 cert, err := tls.X509KeyPair(certfile, pemKey)
759 if err != nil {
760 return nil, fmt.Errorf("unable to load cert: %w", err)
761 }
762 tlsConfig.Certificates = []tls.Certificate{cert}
763 }
764
765
766
767
768 if sslsni == "1" && net.ParseIP(host) == nil {
769 tlsConfig.ServerName = host
770 }
771
772 switch sslmode {
773 case "allow":
774 return []*tls.Config{nil, tlsConfig}, nil
775 case "prefer":
776 return []*tls.Config{tlsConfig, nil}, nil
777 case "require", "verify-ca", "verify-full":
778 return []*tls.Config{tlsConfig}, nil
779 default:
780 panic("BUG: bad sslmode should already have been caught")
781 }
782 }
783
// parsePort converts a decimal port string to a uint16, rejecting anything
// that is not a number in the range 1-65535.
func parsePort(s string) (uint16, error) {
	n, err := strconv.ParseUint(s, 10, 16)
	switch {
	case err != nil:
		return 0, err
	case n == 0 || n > math.MaxUint16:
		return 0, errors.New("outside range")
	default:
		return uint16(n), nil
	}
}
794
// makeDefaultDialer returns a new net.Dialer with a 5 minute TCP keep-alive
// period and no dial timeout.
func makeDefaultDialer() *net.Dialer {
	dialer := new(net.Dialer)
	dialer.KeepAlive = 5 * time.Minute
	return dialer
}

// makeDefaultResolver returns the resolver used for host lookups, which is
// net.DefaultResolver.
func makeDefaultResolver() *net.Resolver {
	resolver := net.DefaultResolver
	return resolver
}
802
803 func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
804 return func(r io.Reader, w io.Writer) Frontend {
805 cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
806 if err != nil {
807 panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
808 }
809 frontend := pgproto3.NewFrontend(cr, w)
810
811 return frontend
812 }
813 }
814
// parseConnectTimeoutSetting parses connect_timeout, a whole number of
// seconds, into a time.Duration. Zero is allowed. Negative values are
// rejected, as are values whose conversion to nanoseconds would overflow
// int64 and wrap around to a bogus (possibly negative) duration.
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
	timeout, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return 0, err
	}
	if timeout < 0 {
		return 0, errors.New("negative timeout")
	}
	if timeout > math.MaxInt64/int64(time.Second) {
		return 0, errors.New("timeout too large")
	}
	return time.Duration(timeout) * time.Second, nil
}
825
826 func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
827 d := makeDefaultDialer()
828 d.Timeout = timeout
829 return d.DialContext
830 }
831
832
833
834 func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
835 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
836 if result.Err != nil {
837 return result.Err
838 }
839
840 if string(result.Rows[0][0]) == "on" {
841 return errors.New("read only connection")
842 }
843
844 return nil
845 }
846
847
848
849 func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
850 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
851 if result.Err != nil {
852 return result.Err
853 }
854
855 if string(result.Rows[0][0]) != "on" {
856 return errors.New("connection is not read only")
857 }
858
859 return nil
860 }
861
862
863
864 func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
865 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
866 if result.Err != nil {
867 return result.Err
868 }
869
870 if string(result.Rows[0][0]) != "t" {
871 return errors.New("server is not in hot standby mode")
872 }
873
874 return nil
875 }
876
877
878
879 func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
880 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
881 if result.Err != nil {
882 return result.Err
883 }
884
885 if string(result.Rows[0][0]) == "t" {
886 return errors.New("server is in standby mode")
887 }
888
889 return nil
890 }
891
892
893
894 func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
895 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
896 if result.Err != nil {
897 return result.Err
898 }
899
900 if string(result.Rows[0][0]) != "t" {
901 return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
902 }
903
904 return nil
905 }
906
View as plain text