1 package pgconn
2
3 import (
4 "context"
5 "crypto/md5"
6 "crypto/tls"
7 "encoding/binary"
8 "encoding/hex"
9 "errors"
10 "fmt"
11 "io"
12 "math"
13 "net"
14 "strconv"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/jackc/pgx/v5/internal/iobufpool"
20 "github.com/jackc/pgx/v5/internal/pgio"
21 "github.com/jackc/pgx/v5/pgconn/internal/bgreader"
22 "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
23 "github.com/jackc/pgx/v5/pgproto3"
24 )
25
// Connection lifecycle states stored in PgConn.status. The numeric order is significant: IsClosed reports every
// state below connStatusIdle (uninitialized, connecting, closed) as closed, so new states must be placed with care.
const (
	connStatusUninitialized = iota // zero value of a fresh PgConn; lock() refuses to use it
	connStatusConnecting           // set by connect() while the startup handshake is running
	connStatusClosed               // set by Close, asyncClose, CopyFrom, or a rejecting OnPgError handler
	connStatusIdle                 // ready for the next operation (after ReadyForQuery in connect, or unlock)
	connStatusBusy                 // an operation currently holds the connection via lock()
)
33
34
35
// Notice is a notice response reported by the PostgreSQL server. It shares PgError's layout; noticeResponseToNotice
// builds one by converting the NoticeResponse through ErrorResponseToPgError.
type Notice PgError

// Notification is an asynchronous notification delivered by the server (built from a
// pgproto3.NotificationResponse in receiveMessage).
type Notification struct {
	PID     uint32 // process ID of the backend that sent the notification
	Channel string // channel the notification was sent on
	Payload string // payload text supplied with the notification
}

// DialFunc opens a network connection. connect calls it with the network/address pair produced by NetworkAddress,
// and CancelRequest uses it to open the separate cancellation connection.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// LookupFunc resolves host into a list of addresses. expandWithIPs accepts either bare addresses (the configured
// port is kept) or "host:port" strings (the port in the string overrides the configured one).
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)

// BuildFrontendFunc constructs the protocol Frontend for a connection. connect passes the connection's background
// reader as r and the net.Conn as w.
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend

// PgErrorHandler is invoked by receiveMessage for every ErrorResponse received from the server. Returning false
// makes receiveMessage mark the connection closed, close the network connection, and return the error; returning
// true lets processing continue normally.
type PgErrorHandler func(*PgConn, *PgError) bool

// NoticeHandler is invoked synchronously by receiveMessage for every NoticeResponse received from the server.
// It runs on the goroutine that is reading from the connection, so it should not block for long.
type NoticeHandler func(*PgConn, *Notice)

// NotificationHandler is invoked synchronously by receiveMessage for every NotificationResponse received from the
// server. It runs on the goroutine that is reading from the connection, so it should not block for long.
type NotificationHandler func(*PgConn, *Notification)
72
73
// PgConn is a low-level PostgreSQL connection handle. lock()/unlock() only read and write the plain status byte
// without synchronization, so a PgConn presumably must not be used from multiple goroutines at once (TODO confirm
// against package docs).
type PgConn struct {
	conn              net.Conn          // network connection; replaced by the TLS connection when TLS is negotiated
	pid               uint32            // backend process ID from BackendKeyData; used by CancelRequest
	secretKey         uint32            // cancellation secret from BackendKeyData; used by CancelRequest
	parameterStatuses map[string]string // latest values reported via ParameterStatus messages
	txStatus          byte              // TxStatus of the most recent ReadyForQuery
	frontend          *pgproto3.Frontend
	bgReader          *bgreader.BGReader // background reader wrapping conn; started by slowWriteTimer's callback
	slowWriteTimer    *time.Timer        // created stopped in connect; its callback starts bgReader
	bgReaderStarted   chan struct{}      // receives a value after slowWriteTimer's callback has started bgReader

	config *Config

	status byte // one of the connStatus* constants

	// State for signalMessage: a background Receive stores its result here. The mutex is held while that
	// receive is in flight so peekMessage can wait for it.
	bufferingReceive    bool
	bufferingReceiveMux sync.Mutex
	bufferingReceiveMsg pgproto3.BackendMessage
	bufferingReceiveErr error

	peekedMsg pgproto3.BackendMessage // message returned by peekMessage but not yet consumed by receiveMessage

	// Reusable readers handed out by Exec / ExecParams / ExecPrepared (and presumably pipeline mode) so that
	// each call does not need to allocate a new reader.
	resultReader      ResultReader
	multiResultReader MultiResultReader
	pipeline          Pipeline
	contextWatcher    *ctxwatch.ContextWatcher
	fieldDescriptions [16]FieldDescription // presumably scratch storage for row descriptions — TODO confirm

	cleanupDone chan struct{} // closed once the connection has been fully closed
}
105
106
107
// Connect parses connString with ParseConfig and then establishes a connection via ConnectConfig. Any parse error
// is returned unchanged without attempting a connection.
func Connect(ctx context.Context, connString string) (*PgConn, error) {
	cfg, parseErr := ParseConfig(connString)
	if parseErr == nil {
		return ConnectConfig(ctx, cfg)
	}
	return nil, parseErr
}
116
117
118
119
// ConnectWithOptions behaves like Connect but parses connString with ParseConfigWithOptions using
// parseConfigOptions. Any parse error is returned unchanged without attempting a connection.
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
	cfg, parseErr := ParseConfigWithOptions(connString, parseConfigOptions)
	if parseErr == nil {
		return ConnectConfig(ctx, cfg)
	}
	return nil, parseErr
}
128
129
130
131
132
133
134
135
// ConnectConfig establishes a connection to a PostgreSQL server using config. It panics if config was not created by
// ParseConfig.
//
// The primary host/port and every entry of config.Fallbacks are expanded into concrete addresses with
// config.LookupFunc and tried in order. The first successful connection is returned.
//
// Some errors end the search immediately:
//   - invalid password (28P01);
//   - invalid authorization (28000) on a TLS attempt;
//   - unknown database (3D000);
//   - missing privilege (42501).
//
// If ValidateConnect rejected servers with a *NotPreferredError and no preferred server was found, the last such
// server is connected to again with that error ignored.
//
// When config.ConnectTimeout is non-zero, each distinct host gets its own timeout context derived from octx.
// Consecutive addresses resolved from the same host share one. On success config.AfterConnect, if set, is run; its
// failure closes the connection.
func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) {
	// Configs are expected to come from ParseConfig (presumably it applies defaults); refuse anything else.
	if !config.createdByParseConfig {
		panic("config must be created by ParseConfig")
	}

	// Treat the primary host and the fallbacks uniformly.
	fallbackConfigs := []*FallbackConfig{
		{
			Host:      config.Host,
			Port:      config.Port,
			TLSConfig: config.TLSConfig,
		},
	}
	fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
	ctx := octx
	fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
	if err != nil {
		return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
	}

	if len(fallbackConfigs) == 0 {
		return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
	}

	foundBestServer := false
	var fallbackConfig *FallbackConfig
	for i, fc := range fallbackConfigs {
		// ConnectTimeout bounds the whole attempt against one host.
		if config.ConnectTimeout != 0 {
			// Start a new timeout for the first entry and whenever the host changes. Cancels are deferred to
			// function exit because ctx is still used after the loop.
			if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) {
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
				defer cancel()
			}
		} else {
			ctx = octx
		}
		pgConn, err = connect(ctx, config, fc, false)
		if err == nil {
			foundBestServer = true
			break
		} else if pgerr, ok := err.(*PgError); ok {
			err = &ConnectError{Config: config, msg: "server error", err: pgerr}
			const ERRCODE_INVALID_PASSWORD = "28P01"                    // wrong password
			const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or pg_hba.conf rejection
			const ERRCODE_INVALID_CATALOG_NAME = "3D000"                // database does not exist
			const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501"              // missing CONNECT privilege
			// These errors would most likely repeat on every other server, so stop trying.
			if pgerr.Code == ERRCODE_INVALID_PASSWORD ||
				pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil ||
				pgerr.Code == ERRCODE_INVALID_CATALOG_NAME ||
				pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
				break
			}
		} else if cerr, ok := err.(*ConnectError); ok {
			// Remember the most recent server that ValidateConnect marked as merely "not preferred".
			if _, ok := cerr.err.(*NotPreferredError); ok {
				fallbackConfig = fc
			}
		}
	}

	// No preferred server: settle for the last acceptable one. Note that ctx here is the context of the last
	// attempt, whose timeout may already be partly or fully consumed.
	if !foundBestServer && fallbackConfig != nil {
		pgConn, err = connect(ctx, config, fallbackConfig, true)
		if pgerr, ok := err.(*PgError); ok {
			err = &ConnectError{Config: config, msg: "server error", err: pgerr}
		}
	}

	// Every error path above already produced a *ConnectError, so no further wrapping is needed.
	if err != nil {
		return nil, err
	}

	if config.AfterConnect != nil {
		err := config.AfterConnect(ctx, pgConn)
		if err != nil {
			pgConn.conn.Close()
			return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
		}
	}

	return pgConn, nil
}
220
// expandWithIPs resolves each fallback host into concrete addresses with lookupFn and returns one FallbackConfig per
// address, preserving order.
//
// Resolution rules:
//   - Hosts for which isAbsolutePath is true (presumably Unix-domain socket directories) pass through unresolved.
//   - A lookup result in "host:port" form overrides the configured port.
//   - Any other lookup result keeps the configured port.
//
// Lookup failures are tolerated as long as at least one address is produced; if none is, the first lookup error is
// returned. A port that does not parse as a uint16 aborts the expansion with an error.
func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
	var expanded []*FallbackConfig
	var firstLookupErr error

	for _, fallback := range fallbacks {
		if isAbsolutePath(fallback.Host) {
			expanded = append(expanded, &FallbackConfig{
				Host:      fallback.Host,
				Port:      fallback.Port,
				TLSConfig: fallback.TLSConfig,
			})
			continue
		}

		addrs, lookupErr := lookupFn(ctx, fallback.Host)
		if lookupErr != nil {
			if firstLookupErr == nil {
				firstLookupErr = lookupErr
			}
			continue
		}

		for _, addr := range addrs {
			host, portText, splitErr := net.SplitHostPort(addr)
			if splitErr != nil {
				// No port in the lookup result: keep the configured one.
				expanded = append(expanded, &FallbackConfig{
					Host:      addr,
					Port:      fallback.Port,
					TLSConfig: fallback.TLSConfig,
				})
				continue
			}

			port, parseErr := strconv.ParseUint(portText, 10, 16)
			if parseErr != nil {
				return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", portText, parseErr)
			}
			expanded = append(expanded, &FallbackConfig{
				Host:      host,
				Port:      uint16(port),
				TLSConfig: fallback.TLSConfig,
			})
		}
	}

	if len(expanded) == 0 && firstLookupErr != nil {
		return nil, firstLookupErr
	}
	return expanded, nil
}
274
// connect performs a single connection attempt against fallbackConfig. It dials, optionally negotiates TLS, sends
// the startup message, and answers authentication requests until the server reports ReadyForQuery.
//
// If config.ValidateConnect is set it is run once the server is ready. A *NotPreferredError from it is ignored (and
// the connection returned) only when ignoreNotPreferredErr is true.
//
// Errors reported by the server are returned as *PgError, unwrapped, so ConnectConfig can inspect their codes. Every
// other failure is returned as a *ConnectError.
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
	ignoreNotPreferredErr bool,
) (*PgConn, error) {
	pgConn := new(PgConn)
	pgConn.config = config
	pgConn.cleanupDone = make(chan struct{})

	network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
	netConn, err := config.DialFunc(ctx, network, address)
	if err != nil {
		return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
	}

	pgConn.conn = netConn
	pgConn.contextWatcher = newContextWatcher(netConn)
	pgConn.contextWatcher.Watch(ctx)

	if fallbackConfig.TLSConfig != nil {
		nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
		// Always stop watching the plain connection; the TLS connection gets its own watcher below.
		pgConn.contextWatcher.Unwatch()
		if err != nil {
			netConn.Close()
			return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
		}

		pgConn.conn = nbTLSConn
		pgConn.contextWatcher = newContextWatcher(nbTLSConn)
		pgConn.contextWatcher.Watch(ctx)
	}

	defer pgConn.contextWatcher.Unwatch()

	pgConn.parameterStatuses = make(map[string]string)
	pgConn.status = connStatusConnecting
	pgConn.bgReader = bgreader.New(pgConn.conn)
	// The timer is created stopped; its callback starts the background reader and signals bgReaderStarted.
	pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
		func() {
			pgConn.bgReader.Start()
			pgConn.bgReaderStarted <- struct{}{}
		},
	)
	pgConn.slowWriteTimer.Stop()
	pgConn.bgReaderStarted = make(chan struct{})
	pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)

	startupMsg := pgproto3.StartupMessage{
		ProtocolVersion: pgproto3.ProtocolVersionNumber,
		Parameters:      make(map[string]string),
	}

	// Copy the configured run-time parameters; user and database below take precedence over them.
	for k, v := range config.RuntimeParams {
		startupMsg.Parameters[k] = v
	}

	startupMsg.Parameters["user"] = config.User
	if config.Database != "" {
		startupMsg.Parameters["database"] = config.Database
	}

	pgConn.frontend.Send(&startupMsg)
	if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
		pgConn.conn.Close()
		return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
	}

	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.conn.Close()
			if err, ok := err.(*PgError); ok {
				return nil, err
			}
			return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
		}

		switch msg := msg.(type) {
		case *pgproto3.BackendKeyData:
			pgConn.pid = msg.ProcessID
			pgConn.secretKey = msg.SecretKey

		case *pgproto3.AuthenticationOk:
		case *pgproto3.AuthenticationCleartextPassword:
			err = pgConn.txPasswordMessage(pgConn.config.Password)
			if err != nil {
				pgConn.conn.Close()
				return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
			}
		case *pgproto3.AuthenticationMD5Password:
			// "md5" + md5hex(md5hex(password + user) + salt)
			digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
			err = pgConn.txPasswordMessage(digestedPassword)
			if err != nil {
				pgConn.conn.Close()
				return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
			}
		case *pgproto3.AuthenticationSASL:
			err = pgConn.scramAuth(msg.AuthMechanisms)
			if err != nil {
				pgConn.conn.Close()
				return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
			}
		case *pgproto3.AuthenticationGSS:
			err = pgConn.gssAuth()
			if err != nil {
				pgConn.conn.Close()
				return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
			}
		case *pgproto3.ReadyForQuery:
			pgConn.status = connStatusIdle
			if config.ValidateConnect != nil {
				// ValidateConnect may run commands that watch the context again, so stop the current watch first.
				// Nothing else happens after ValidateConnect, so there is no need to restart it.
				pgConn.contextWatcher.Unwatch()

				err := config.ValidateConnect(ctx, pgConn)
				if err != nil {
					if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok {
						return pgConn, nil
					}
					pgConn.conn.Close()
					return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
				}
			}
			return pgConn, nil
		case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse:
			// Already processed by receiveMessage.
		case *pgproto3.ErrorResponse:
			pgConn.conn.Close()
			return nil, ErrorResponseToPgError(msg)
		default:
			pgConn.conn.Close()
			// err is always nil here; report which message arrived so the failure is diagnosable.
			return nil, &ConnectError{Config: config, msg: "received unexpected message", err: fmt.Errorf("unexpected message type %T", msg)}
		}
	}
}
414
// newContextWatcher returns a ContextWatcher bound to conn. On cancellation it sets conn's deadline to a moment far
// in the past, which makes pending and future I/O on conn fail with a timeout. The second callback clears the
// deadline again (presumably run when the watch ends after a cancellation — see ctxwatch).
func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
	expireDeadline := func() {
		conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC))
	}
	clearDeadline := func() {
		conn.SetDeadline(time.Time{})
	}
	return ctxwatch.NewContextWatcher(expireDeadline, clearDeadline)
}
421
422 func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
423 err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
424 if err != nil {
425 return nil, err
426 }
427
428 response := make([]byte, 1)
429 if _, err = io.ReadFull(conn, response); err != nil {
430 return nil, err
431 }
432
433 if response[0] != 'S' {
434 return nil, errors.New("server refused TLS connection")
435 }
436
437 return tls.Client(conn, tlsConfig), nil
438 }
439
440 func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
441 pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password})
442 return pgConn.flushWithPotentialWriteReadDeadlock()
443 }
444
445 func hexMD5(s string) string {
446 hash := md5.New()
447 io.WriteString(hash, s)
448 return hex.EncodeToString(hash.Sum(nil))
449 }
450
// signalMessage starts a background receive of the next backend message and returns a channel that is closed once
// the message (or error) has been stored in bufferingReceiveMsg / bufferingReceiveErr.
//
// bufferingReceiveMux is locked here and unlocked by the goroutine after the receive completes. peekMessage's
// Lock() therefore blocks until the buffered result is available. Panics if a buffered receive is already pending.
func (pgConn *PgConn) signalMessage() chan struct{} {
	if pgConn.bufferingReceive {
		panic("BUG: signalMessage when already in progress")
	}

	pgConn.bufferingReceive = true
	// Locked on this goroutine, unlocked by the receiving goroutine below.
	pgConn.bufferingReceiveMux.Lock()

	ch := make(chan struct{})
	go func() {
		pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
		pgConn.bufferingReceiveMux.Unlock()
		close(ch)
	}()

	return ch
}
468
469
470
471
472
473
474
475
// ReceiveMessage locks the connection and reads the next message from the server.
//
// The message passes through receiveMessage, so it updates ParameterStatus/TxStatus state and triggers the
// notice, notification and OnPgError handlers as usual. If ctx is already done, nothing is read.
//
// A receive failure is wrapped in a pgconnError (with normalized timeout errors) marked safeToRetry.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
	if lockErr := pgConn.lock(); lockErr != nil {
		return nil, lockErr
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	msg, recvErr := pgConn.receiveMessage()
	if recvErr == nil {
		return msg, nil
	}
	return msg, &pgconnError{
		msg:         "receive message failed",
		err:         normalizeTimeoutError(ctx, recvErr),
		safeToRetry: true,
	}
}
502
503
// peekMessage returns the next backend message without consuming it; repeated calls return the same message until
// receiveMessage clears peekedMsg.
//
// If a signalMessage receive is pending, its buffered result is used (waiting on the mutex until it is ready).
// Any error other than a network timeout is treated as fatal and the connection is closed asynchronously.
func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
	if pgConn.peekedMsg != nil {
		return pgConn.peekedMsg, nil
	}

	var msg pgproto3.BackendMessage
	var err error
	if pgConn.bufferingReceive {
		// Blocks until the signalMessage goroutine has stored its result and unlocked.
		pgConn.bufferingReceiveMux.Lock()
		msg = pgConn.bufferingReceiveMsg
		err = pgConn.bufferingReceiveErr
		pgConn.bufferingReceiveMux.Unlock()
		pgConn.bufferingReceive = false

		// A timeout in the background receive is not fatal: try the read again in the foreground.
		var netErr net.Error
		if errors.As(err, &netErr) && netErr.Timeout() {
			msg, err = pgConn.frontend.Receive()
		}
	} else {
		msg, err = pgConn.frontend.Receive()
	}

	if err != nil {
		// Close on anything other than a timeout error - everything else is fatal.
		var netErr net.Error
		isNetErr := errors.As(err, &netErr)
		if !(isNetErr && netErr.Timeout()) {
			pgConn.asyncClose()
		}

		return nil, err
	}

	pgConn.peekedMsg = msg
	return msg, nil
}
541
542
// receiveMessage consumes the next backend message and applies its connection-level side effects:
//   - ReadyForQuery updates txStatus.
//   - ParameterStatus updates parameterStatuses.
//   - NoticeResponse and NotificationResponse invoke the configured handlers.
//   - ErrorResponse is offered to OnPgError. If the handler returns false, the connection is closed immediately and
//     the *PgError is returned as the error.
//
// In all other cases the message itself is returned with a nil error, including ErrorResponse.
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
	msg, err := pgConn.peekMessage()
	if err != nil {
		return nil, err
	}
	pgConn.peekedMsg = nil

	switch msg := msg.(type) {
	case *pgproto3.ReadyForQuery:
		pgConn.txStatus = msg.TxStatus
	case *pgproto3.ParameterStatus:
		pgConn.parameterStatuses[msg.Name] = msg.Value
	case *pgproto3.ErrorResponse:
		err := ErrorResponseToPgError(msg)
		if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) {
			pgConn.status = connStatusClosed
			pgConn.conn.Close() // Close error ignored: the connection is being abandoned and err is returned instead.
			close(pgConn.cleanupDone)
			return nil, err
		}
	case *pgproto3.NoticeResponse:
		if pgConn.config.OnNotice != nil {
			pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg))
		}
	case *pgproto3.NotificationResponse:
		if pgConn.config.OnNotification != nil {
			pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
		}
	}

	return msg, nil
}
575
576
577
// Conn returns the underlying network connection. When TLS was negotiated this is the TLS-wrapped connection.
func (pgConn *PgConn) Conn() net.Conn {
	underlying := pgConn.conn
	return underlying
}
581
582
583 func (pgConn *PgConn) PID() uint32 {
584 return pgConn.pid
585 }
586
587
588
589
590
591
592
593
594
595
596 func (pgConn *PgConn) TxStatus() byte {
597 return pgConn.txStatus
598 }
599
600
601 func (pgConn *PgConn) SecretKey() uint32 {
602 return pgConn.secretKey
603 }
604
605
606 func (pgConn *PgConn) Frontend() *pgproto3.Frontend {
607 return pgConn.frontend
608 }
609
610
611
612
// Close closes the connection. Calling Close on an already closed connection is a no-op returning nil.
//
// Close sends a Terminate message on a best-effort basis and then closes the underlying network connection and the
// cleanupDone channel. When ctx is not context.Background it is watched while Terminate is sent, so cancelling ctx
// interrupts that write.
func (pgConn *PgConn) Close(ctx context.Context) error {
	if pgConn.status == connStatusClosed {
		return nil
	}
	pgConn.status = connStatusClosed

	// Deferred LIFO: the socket is closed before cleanupDone is signalled.
	defer close(pgConn.cleanupDone)
	defer pgConn.conn.Close()

	if ctx != context.Background() {
		// Close may be called while a watched operation is still in progress (e.g. from a deferred cleanup after a
		// panic). End any existing watch before starting a new one; Unwatch is called even if no watch is active.
		pgConn.contextWatcher.Unwatch()

		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Errors sending Terminate are deliberately ignored: the connection is being torn down regardless.
	pgConn.frontend.Send(&pgproto3.Terminate{})
	pgConn.flushWithPotentialWriteReadDeadlock()

	return pgConn.conn.Close()
}
644
645
646
// asyncClose marks the connection closed immediately and finishes tearing it down in a background goroutine.
//
// Within a 15 second deadline, the goroutine sends a cancel request for any running query, then a best-effort
// Terminate message. It then closes the socket and, last, cleanupDone.
//
// Callers in this file use it after a write or receive error leaves the connection in an unknown state.
func (pgConn *PgConn) asyncClose() {
	if pgConn.status == connStatusClosed {
		return
	}
	pgConn.status = connStatusClosed

	go func() {
		// Deferred LIFO: the socket is closed before cleanupDone is signalled.
		defer close(pgConn.cleanupDone)
		defer pgConn.conn.Close()

		deadline := time.Now().Add(time.Second * 15)

		ctx, cancel := context.WithDeadline(context.Background(), deadline)
		defer cancel()

		// Errors are ignored: this is best-effort cleanup of an already broken connection.
		pgConn.CancelRequest(ctx)

		pgConn.conn.SetDeadline(deadline)

		pgConn.frontend.Send(&pgproto3.Terminate{})
		pgConn.flushWithPotentialWriteReadDeadlock()
	}()
}
670
671
672
673
674
675
676
677
678
679 func (pgConn *PgConn) CleanupDone() chan (struct{}) {
680 return pgConn.cleanupDone
681 }
682
683
684
685
686 func (pgConn *PgConn) IsClosed() bool {
687 return pgConn.status < connStatusIdle
688 }
689
690
691 func (pgConn *PgConn) IsBusy() bool {
692 return pgConn.status == connStatusBusy
693 }
694
695
696 func (pgConn *PgConn) lock() error {
697 switch pgConn.status {
698 case connStatusBusy:
699 return &connLockError{status: "conn busy"}
700 case connStatusClosed:
701 return &connLockError{status: "conn closed"}
702 case connStatusUninitialized:
703 return &connLockError{status: "conn uninitialized"}
704 }
705 pgConn.status = connStatusBusy
706 return nil
707 }
708
709 func (pgConn *PgConn) unlock() {
710 switch pgConn.status {
711 case connStatusBusy:
712 pgConn.status = connStatusIdle
713 case connStatusClosed:
714 default:
715 panic("BUG: cannot unlock unlocked connection")
716 }
717 }
718
719
720
721 func (pgConn *PgConn) ParameterStatus(key string) string {
722 return pgConn.parameterStatuses[key]
723 }
724
725
726 type CommandTag struct {
727 s string
728 }
729
730
731 func NewCommandTag(s string) CommandTag {
732 return CommandTag{s: s}
733 }
734
735
736
737 func (ct CommandTag) RowsAffected() int64 {
738
739 idx := -1
740 for i := len(ct.s) - 1; i >= 0; i-- {
741 if ct.s[i] >= '0' && ct.s[i] <= '9' {
742 idx = i
743 } else {
744 break
745 }
746 }
747
748 if idx == -1 {
749 return 0
750 }
751
752 var n int64
753 for _, b := range ct.s[idx:] {
754 n = n*10 + int64(b-'0')
755 }
756
757 return n
758 }
759
760 func (ct CommandTag) String() string {
761 return ct.s
762 }
763
764
765 func (ct CommandTag) Insert() bool {
766 return strings.HasPrefix(ct.s, "INSERT")
767 }
768
769
770 func (ct CommandTag) Update() bool {
771 return strings.HasPrefix(ct.s, "UPDATE")
772 }
773
774
775 func (ct CommandTag) Delete() bool {
776 return strings.HasPrefix(ct.s, "DELETE")
777 }
778
779
780 func (ct CommandTag) Select() bool {
781 return strings.HasPrefix(ct.s, "SELECT")
782 }
783
// FieldDescription describes one result column. It is a copy of a pgproto3.RowDescription field with Name converted
// to a string (see convertRowDescription).
type FieldDescription struct {
	Name                 string
	TableOID             uint32 // OID of the source table, as reported by the server
	TableAttributeNumber uint16 // column number within that table, as reported by the server
	DataTypeOID          uint32
	DataTypeSize         int16
	TypeModifier         int32
	Format               int16 // result format code, as reported by the server
}
793
794 func (pgConn *PgConn) convertRowDescription(dst []FieldDescription, rd *pgproto3.RowDescription) []FieldDescription {
795 if cap(dst) >= len(rd.Fields) {
796 dst = dst[:len(rd.Fields):len(rd.Fields)]
797 } else {
798 dst = make([]FieldDescription, len(rd.Fields))
799 }
800
801 for i := range rd.Fields {
802 dst[i].Name = string(rd.Fields[i].Name)
803 dst[i].TableOID = rd.Fields[i].TableOID
804 dst[i].TableAttributeNumber = rd.Fields[i].TableAttributeNumber
805 dst[i].DataTypeOID = rd.Fields[i].DataTypeOID
806 dst[i].DataTypeSize = rd.Fields[i].DataTypeSize
807 dst[i].TypeModifier = rd.Fields[i].TypeModifier
808 dst[i].Format = rd.Fields[i].Format
809 }
810
811 return dst
812 }
813
// StatementDescription is the result of Prepare: the statement name and SQL together with the parameter type OIDs
// and result columns reported by the server.
type StatementDescription struct {
	Name      string
	SQL       string
	ParamOIDs []uint32           // from the server's ParameterDescription
	Fields    []FieldDescription // from the server's RowDescription; nil if none was sent
}
820
821
822
823
824
825
// Prepare creates a prepared statement named name for sql on the server.
//
// It sends Parse, Describe (statement) and Sync, then collects the parameter OIDs and result column descriptions
// from the replies. paramOIDs supplies optional parameter type hints.
//
// A server error is returned only after ReadyForQuery has been read, so the connection stays in sync. A network
// failure closes the connection.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
	if lockErr := pgConn.lock(); lockErr != nil {
		return nil, lockErr
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
	pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
	pgConn.frontend.SendSync(&pgproto3.Sync{})
	if flushErr := pgConn.flushWithPotentialWriteReadDeadlock(); flushErr != nil {
		pgConn.asyncClose()
		return nil, flushErr
	}

	desc := &StatementDescription{Name: name, SQL: sql}
	var serverErr error

	for {
		msg, recvErr := pgConn.receiveMessage()
		if recvErr != nil {
			pgConn.asyncClose()
			return nil, normalizeTimeoutError(ctx, recvErr)
		}

		switch m := msg.(type) {
		case *pgproto3.ParameterDescription:
			desc.ParamOIDs = make([]uint32, len(m.ParameterOIDs))
			copy(desc.ParamOIDs, m.ParameterOIDs)
		case *pgproto3.RowDescription:
			desc.Fields = pgConn.convertRowDescription(nil, m)
		case *pgproto3.ErrorResponse:
			serverErr = ErrorResponseToPgError(m)
		case *pgproto3.ReadyForQuery:
			if serverErr != nil {
				return nil, serverErr
			}
			return desc, nil
		}
	}
}
881
882
883
884
885
886
887
// Deallocate closes the prepared statement name on the server using the protocol-level Close message (followed by
// Sync) rather than a DEALLOCATE SQL statement.
//
// A server error is returned only after the matching ReadyForQuery has been consumed, so the connection stays in
// sync for the next command. A network failure closes the connection.
func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error {
	if err := pgConn.lock(); err != nil {
		return err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	pgConn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
	pgConn.frontend.SendSync(&pgproto3.Sync{})
	err := pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		return err
	}

	var pgErr error
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return normalizeTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ErrorResponse:
			// Record the error but keep reading. After an error the server discards messages up to the Sync and
			// then still sends ReadyForQuery; returning here would leave it unread and desynchronize the next
			// command on this connection.
			pgErr = ErrorResponseToPgError(msg)
		case *pgproto3.ReadyForQuery:
			return pgErr
		}
	}
}
927
928
929 func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
930 return &PgError{
931 Severity: msg.Severity,
932 Code: string(msg.Code),
933 Message: string(msg.Message),
934 Detail: string(msg.Detail),
935 Hint: msg.Hint,
936 Position: msg.Position,
937 InternalPosition: msg.InternalPosition,
938 InternalQuery: string(msg.InternalQuery),
939 Where: string(msg.Where),
940 SchemaName: string(msg.SchemaName),
941 TableName: string(msg.TableName),
942 ColumnName: string(msg.ColumnName),
943 DataTypeName: string(msg.DataTypeName),
944 ConstraintName: msg.ConstraintName,
945 File: string(msg.File),
946 Line: msg.Line,
947 Routine: string(msg.Routine),
948 }
949 }
950
951 func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
952 pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg))
953 return (*Notice)(pgerr)
954 }
955
956
957
958
// CancelRequest asks the server to cancel the query currently running on this connection.
//
// It opens a separate connection and sends a 16-byte CancelRequest message: length 16, request code 80877102, then
// the backend PID and secret key. It then waits for the server to close that connection.
//
// A nil error only means the request was sent; it does not guarantee that anything was cancelled.
func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
	// Target the server this connection is actually attached to rather than re-resolving the config, so the request
	// reaches the same server even with fallbacks or DNS-based load balancing.
	serverAddr := pgConn.conn.RemoteAddr()
	var serverNetwork, serverAddress string
	if serverAddr.Network() == "unix" {
		// For unix sockets RemoteAddr() may not yield a dialable path; use the configured host and port instead.
		serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port)
	} else {
		serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String()
	}

	// The unix-socket case above already dials the configured address, so a failure here is final. Retrying would
	// dial exactly the same network/address again.
	cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress)
	if err != nil {
		return err
	}
	defer cancelConn.Close()

	if ctx != context.Background() {
		contextWatcher := ctxwatch.NewContextWatcher(
			func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
			func() { cancelConn.SetDeadline(time.Time{}) },
		)
		contextWatcher.Watch(ctx)
		defer contextWatcher.Unwatch()
	}

	buf := make([]byte, 16)
	binary.BigEndian.PutUint32(buf[0:4], 16)
	binary.BigEndian.PutUint32(buf[4:8], 80877102)
	binary.BigEndian.PutUint32(buf[8:12], pgConn.pid)
	binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey)

	if _, err := cancelConn.Write(buf); err != nil {
		return fmt.Errorf("write to connection for cancellation: %w", err)
	}

	// The server is expected to close the cancel connection once it has processed the request. Wait for that, or
	// for the context watcher to expire the deadline; the outcome of the read itself is irrelevant.
	_, _ = cancelConn.Read(buf)

	return nil
}
1014
1015
1016
1017 func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
1018 if err := pgConn.lock(); err != nil {
1019 return err
1020 }
1021 defer pgConn.unlock()
1022
1023 if ctx != context.Background() {
1024 select {
1025 case <-ctx.Done():
1026 return newContextAlreadyDoneError(ctx)
1027 default:
1028 }
1029
1030 pgConn.contextWatcher.Watch(ctx)
1031 defer pgConn.contextWatcher.Unwatch()
1032 }
1033
1034 for {
1035 msg, err := pgConn.receiveMessage()
1036 if err != nil {
1037 return normalizeTimeoutError(ctx, err)
1038 }
1039
1040 switch msg.(type) {
1041 case *pgproto3.NotificationResponse:
1042 return nil
1043 }
1044 }
1045 }
1046
1047
1048
1049
1050
1051
// Exec sends sql using the simple query protocol and returns a reader for its results, which may be several.
// The returned reader is pgConn's embedded MultiResultReader, so it is overwritten by the next Exec.
//
// If the connection cannot be locked, ctx is already done, or the query cannot be sent, the returned reader is
// already closed and carries the error. Otherwise the caller must consume the reader to release the connection.
func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
	if lockErr := pgConn.lock(); lockErr != nil {
		return &MultiResultReader{closed: true, err: lockErr}
	}

	pgConn.multiResultReader = MultiResultReader{pgConn: pgConn, ctx: ctx}
	mrr := &pgConn.multiResultReader

	// abort marks mrr closed with err and releases the connection lock.
	abort := func(err error) *MultiResultReader {
		mrr.closed = true
		mrr.err = err
		pgConn.unlock()
		return mrr
	}

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return abort(newContextAlreadyDoneError(ctx))
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
	}

	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
	if flushErr := pgConn.flushWithPotentialWriteReadDeadlock(); flushErr != nil {
		pgConn.asyncClose()
		pgConn.contextWatcher.Unwatch()
		return abort(flushErr)
	}

	return mrr
}
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
// ExecParams executes sql once via the extended protocol using the unnamed statement: Parse, Bind, Describe
// (portal), Execute and Sync.
//
// paramValues, paramOIDs, paramFormats and resultFormats are passed through to Parse/Bind unchanged. If the
// connection cannot be used, the returned reader is already closed and carries the error.
func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader {
	rr := pgConn.execExtendedPrefix(ctx, paramValues)
	if !rr.closed {
		pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
		pgConn.frontend.SendBind(&pgproto3.Bind{
			ParameterFormatCodes: paramFormats,
			Parameters:           paramValues,
			ResultFormatCodes:    resultFormats,
		})
		pgConn.execExtendedSuffix(rr)
	}
	return rr
}
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
// ExecPrepared executes the already-prepared statement stmtName via the extended protocol: Bind, Describe (portal),
// Execute and Sync. If the connection cannot be used, the returned reader is already closed and carries the error.
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader {
	rr := pgConn.execExtendedPrefix(ctx, paramValues)
	if !rr.closed {
		pgConn.frontend.SendBind(&pgproto3.Bind{
			PreparedStatement:    stmtName,
			ParameterFormatCodes: paramFormats,
			Parameters:           paramValues,
			ResultFormatCodes:    resultFormats,
		})
		pgConn.execExtendedSuffix(rr)
	}
	return rr
}
1148
// execExtendedPrefix locks the connection for an extended-protocol execution and prepares the connection's
// embedded ResultReader. On success the connection is locked, ctx is watched (unless it is context.Background), and
// the caller must queue its messages and call execExtendedSuffix.
//
// On failure the returned reader is already closed and carries the error, and the connection is left unlocked.
// Failure cases: lock failure, too many parameters, or ctx already done.
func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
	// Lock before touching pgConn.resultReader. If the connection is busy, that embedded struct may belong to
	// the operation currently in progress, and overwriting it would corrupt that caller's reader. A failed lock
	// is therefore reported through a fresh, independent ResultReader.
	if err := pgConn.lock(); err != nil {
		result := &ResultReader{pgConn: pgConn, ctx: ctx}
		result.concludeCommand(CommandTag{}, err)
		result.closed = true
		return result
	}

	pgConn.resultReader = ResultReader{
		pgConn: pgConn,
		ctx:    ctx,
	}
	result := &pgConn.resultReader

	// The Bind message encodes the parameter count as a 16-bit integer.
	if len(paramValues) > math.MaxUint16 {
		result.concludeCommand(CommandTag{}, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
		result.closed = true
		pgConn.unlock()
		return result
	}

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			result.concludeCommand(CommandTag{}, newContextAlreadyDoneError(ctx))
			result.closed = true
			pgConn.unlock()
			return result
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
	}

	return result
}
1183
1184 func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
1185 pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
1186 pgConn.frontend.SendExecute(&pgproto3.Execute{})
1187 pgConn.frontend.SendSync(&pgproto3.Sync{})
1188
1189 err := pgConn.flushWithPotentialWriteReadDeadlock()
1190 if err != nil {
1191 pgConn.asyncClose()
1192 result.concludeCommand(CommandTag{}, err)
1193 pgConn.contextWatcher.Unwatch()
1194 result.closed = true
1195 pgConn.unlock()
1196 return
1197 }
1198
1199 result.readUntilRowDescription()
1200 }
1201
1202
// CopyTo runs sql (expected to be a COPY ... TO STDOUT command) with the simple query protocol and writes every
// CopyData payload to w. It returns the command tag, or the server's error, once ReadyForQuery arrives.
//
// A receive error or a write error on w closes the connection asynchronously.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
	if err := pgConn.lock(); err != nil {
		return CommandTag{}, err
	}

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			pgConn.unlock()
			return CommandTag{}, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Send the COPY command.
	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})

	err := pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		pgConn.unlock()
		return CommandTag{}, err
	}

	// Read results. On the error paths below unlock is not called explicitly: asyncClose has already set the status
	// to connStatusClosed, so the connection is no longer busy.
	var commandTag CommandTag
	var pgErr error
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return CommandTag{}, normalizeTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.CopyDone:
		case *pgproto3.CopyData:
			_, err := w.Write(msg.Data)
			if err != nil {
				pgConn.asyncClose()
				return CommandTag{}, err
			}
		case *pgproto3.ReadyForQuery:
			pgConn.unlock()
			return commandTag, pgErr
		case *pgproto3.CommandComplete:
			commandTag = pgConn.makeCommandTag(msg.CommandTag)
		case *pgproto3.ErrorResponse:
			pgErr = ErrorResponseToPgError(msg)
		}
	}
}
1257
1258
1259
1260
1261
// CopyFrom runs sql (expected to be a COPY ... FROM STDIN command) and streams all of r to the server as CopyData
// messages.
//
// A reader goroutine sends the data while this goroutine watches for server messages via signalMessage, so an
// ErrorResponse can stop the copy early. After r returns io.EOF (or the server reports an error), CopyDone is sent.
// Any other error from r is reported to the server with CopyFail. The command tag is returned once ReadyForQuery
// arrives.
//
// Note: context cancellation only interrupts operations on the network connection; reads from r may still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
	if err := pgConn.lock(); err != nil {
		return CommandTag{}, err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return CommandTag{}, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Send the COPY command.
	pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
	err := pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		return CommandTag{}, err
	}

	// Stream copy data from r in a separate goroutine.
	abortCopyChan := make(chan struct{})
	copyErrChan := make(chan error, 1) // buffered so the goroutine can always deliver its final error and exit
	signalMessageChan := pgConn.signalMessage()
	var wg sync.WaitGroup
	wg.Add(1)

	go func() {
		defer wg.Done()
		buf := iobufpool.Get(65536)
		defer iobufpool.Put(buf)
		// Each CopyData message is framed in place:
		//   - byte 0: the 'd' message type;
		//   - bytes 1-4: the int32 length (payload + 4);
		//   - bytes 5 onward: the payload.
		(*buf)[0] = 'd'

		for {
			n, readErr := r.Read((*buf)[5:cap(*buf)])
			if n > 0 {
				*buf = (*buf)[0 : n+5]
				pgio.SetInt32((*buf)[1:], int32(n+4))

				writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
				if writeErr != nil {
					// Write errors are fatal. Only the socket is closed here; status and cleanupDone are left
					// to the calling goroutine (presumably because asyncClose is not safe to call from this one).
					pgConn.conn.Close()

					copyErrChan <- writeErr
					return
				}
			}
			if readErr != nil {
				// io.EOF also arrives here and signals a normal end of data.
				copyErrChan <- readErr
				return
			}

			select {
			case <-abortCopyChan:
				return
			default:
			}
		}
	}()

	var pgErr error
	var copyErr error
	for copyErr == nil && pgErr == nil {
		select {
		case copyErr = <-copyErrChan:
		case <-signalMessageChan:
			// receiveMessage would call asyncClose on a receive error, which would race with the copy goroutine.
			// Instead check the error stored by signalMessage and, if set, close the connection directly without
			// sending Terminate.
			if err := pgConn.bufferingReceiveErr; err != nil {
				pgConn.status = connStatusClosed
				pgConn.conn.Close()
				close(pgConn.cleanupDone)
				return CommandTag{}, normalizeTimeoutError(ctx, err)
			}
			msg, _ := pgConn.receiveMessage()

			switch msg := msg.(type) {
			case *pgproto3.ErrorResponse:
				pgErr = ErrorResponseToPgError(msg)
			default:
				// Not an error: keep listening for server messages in the background.
				signalMessageChan = pgConn.signalMessage()
			}
		}
	}
	close(abortCopyChan)
	// Make sure the copy goroutine has finished before writing on the connection again.
	wg.Wait()

	if copyErr == io.EOF || pgErr != nil {
		pgConn.frontend.Send(&pgproto3.CopyDone{})
	} else {
		pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
	}
	err = pgConn.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		pgConn.asyncClose()
		return CommandTag{}, err
	}

	// Read the command result up to ReadyForQuery.
	var commandTag CommandTag
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return CommandTag{}, normalizeTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ReadyForQuery:
			return commandTag, pgErr
		case *pgproto3.CommandComplete:
			commandTag = pgConn.makeCommandTag(msg.CommandTag)
		case *pgproto3.ErrorResponse:
			pgErr = ErrorResponseToPgError(msg)
		}
	}
}
1387
1388
// MultiResultReader is a reader for a command that can return multiple results, such as the one returned by
// ExecBatch.
type MultiResultReader struct {
	pgConn   *PgConn
	ctx      context.Context
	pipeline *Pipeline // when non-nil, ReadyForQuery decrements the pipeline's expected count instead of unlocking the conn

	rr *ResultReader // the current result, as returned by ResultReader()

	closed bool  // set once ReadyForQuery or a fatal receive error has been seen
	err    error // most recent ErrorResponse (as *PgError) or the fatal receive error
}
1399
1400
1401 func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
1402 var results []*Result
1403
1404 for mrr.NextResult() {
1405 results = append(results, mrr.ResultReader().Read())
1406 }
1407 err := mrr.Close()
1408
1409 return results, err
1410 }
1411
// receiveMessage receives the next message from the connection and updates the MultiResultReader's state.
//
// A receive error is fatal: the context watch is released, the normalized error is stored in mrr.err, the reader is
// marked closed, and the connection is closed asynchronously. ReadyForQuery marks the reader closed. Inside a pipeline
// it decrements the pipeline's expected ReadyForQuery count; otherwise it releases the context watch and the
// connection lock. An ErrorResponse is stored in mrr.err as a *PgError and the message is still returned.
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
	msg, err := mrr.pgConn.receiveMessage()
	if err != nil {
		mrr.pgConn.contextWatcher.Unwatch()
		mrr.err = normalizeTimeoutError(mrr.ctx, err)
		mrr.closed = true
		mrr.pgConn.asyncClose()
		return nil, mrr.err
	}

	switch msg := msg.(type) {
	case *pgproto3.ReadyForQuery:
		mrr.closed = true
		if mrr.pipeline != nil {
			// Pipeline.Close releases the context watch and the lock; here only the outstanding sync count changes.
			mrr.pipeline.expectedReadyForQueryCount--
		} else {
			mrr.pgConn.contextWatcher.Unwatch()
			mrr.pgConn.unlock()
		}
	case *pgproto3.ErrorResponse:
		mrr.err = ErrorResponseToPgError(msg)
	}

	return msg, nil
}
1437
1438
1439 func (mrr *MultiResultReader) NextResult() bool {
1440 for !mrr.closed && mrr.err == nil {
1441 msg, err := mrr.receiveMessage()
1442 if err != nil {
1443 return false
1444 }
1445
1446 switch msg := msg.(type) {
1447 case *pgproto3.RowDescription:
1448 mrr.pgConn.resultReader = ResultReader{
1449 pgConn: mrr.pgConn,
1450 multiResultReader: mrr,
1451 ctx: mrr.ctx,
1452 fieldDescriptions: mrr.pgConn.convertRowDescription(mrr.pgConn.fieldDescriptions[:], msg),
1453 }
1454
1455 mrr.rr = &mrr.pgConn.resultReader
1456 return true
1457 case *pgproto3.CommandComplete:
1458 mrr.pgConn.resultReader = ResultReader{
1459 commandTag: mrr.pgConn.makeCommandTag(msg.CommandTag),
1460 commandConcluded: true,
1461 closed: true,
1462 }
1463 mrr.rr = &mrr.pgConn.resultReader
1464 return true
1465 case *pgproto3.EmptyQueryResponse:
1466 return false
1467 }
1468 }
1469
1470 return false
1471 }
1472
1473
1474 func (mrr *MultiResultReader) ResultReader() *ResultReader {
1475 return mrr.rr
1476 }
1477
1478
1479 func (mrr *MultiResultReader) Close() error {
1480 for !mrr.closed {
1481 _, err := mrr.receiveMessage()
1482 if err != nil {
1483 return mrr.err
1484 }
1485 }
1486
1487 return mrr.err
1488 }
1489
1490
// ResultReader is a reader for the result of a single query.
type ResultReader struct {
	pgConn            *PgConn
	multiResultReader *MultiResultReader // when set, messages are received through it
	pipeline          *Pipeline          // set when the result comes from a Pipeline
	ctx               context.Context

	fieldDescriptions []FieldDescription
	rowValues         [][]byte // values of the current row; cleared when the command concludes
	commandTag        CommandTag
	commandConcluded  bool // set by CommandComplete, EmptyQueryResponse, ErrorResponse, or a receive error
	closed            bool
	err               error // first error recorded by concludeCommand; Close may overwrite it with a trailing ErrorResponse
}
1504
1505
// Result is the saved query response that is returned by calling Read on a ResultReader.
type Result struct {
	FieldDescriptions []FieldDescription
	Rows              [][][]byte // one entry per row; each row holds the column values copied out of the reader
	CommandTag        CommandTag
	Err               error
}
1512
1513
1514 func (rr *ResultReader) Read() *Result {
1515 br := &Result{}
1516
1517 for rr.NextRow() {
1518 if br.FieldDescriptions == nil {
1519 br.FieldDescriptions = make([]FieldDescription, len(rr.FieldDescriptions()))
1520 copy(br.FieldDescriptions, rr.FieldDescriptions())
1521 }
1522
1523 values := rr.Values()
1524 row := make([][]byte, len(values))
1525 for i := range row {
1526 row[i] = make([]byte, len(values[i]))
1527 copy(row[i], values[i])
1528 }
1529 br.Rows = append(br.Rows, row)
1530 }
1531
1532 br.CommandTag, br.Err = rr.Close()
1533
1534 return br
1535 }
1536
1537
1538 func (rr *ResultReader) NextRow() bool {
1539 for !rr.commandConcluded {
1540 msg, err := rr.receiveMessage()
1541 if err != nil {
1542 return false
1543 }
1544
1545 switch msg := msg.(type) {
1546 case *pgproto3.DataRow:
1547 rr.rowValues = msg.Values
1548 return true
1549 }
1550 }
1551
1552 return false
1553 }
1554
1555
1556
1557
1558 func (rr *ResultReader) FieldDescriptions() []FieldDescription {
1559 return rr.fieldDescriptions
1560 }
1561
1562
1563
1564 func (rr *ResultReader) Values() [][]byte {
1565 return rr.rowValues
1566 }
1567
1568
1569
// Close consumes any remaining result data and returns the command tag and error. A standalone ResultReader (not part
// of a MultiResultReader or Pipeline) also reads up to ReadyForQuery, then releases the context watch and the
// connection lock. Calling Close again returns the same command tag and error.
func (rr *ResultReader) Close() (CommandTag, error) {
	if rr.closed {
		return rr.commandTag, rr.err
	}
	rr.closed = true

	// Drain any rows left in the current result.
	for !rr.commandConcluded {
		_, err := rr.receiveMessage()
		if err != nil {
			// receiveMessage has already recorded the error in rr.err and torn down the reader.
			return CommandTag{}, rr.err
		}
	}

	// MultiResultReader and Pipeline handle ReadyForQuery themselves; a standalone reader owns the lock and must do it.
	if rr.multiResultReader == nil && rr.pipeline == nil {
		for {
			msg, err := rr.receiveMessage()
			if err != nil {
				return CommandTag{}, rr.err
			}

			switch msg := msg.(type) {
			// An ErrorResponse can still arrive after CommandComplete (presumably e.g. a deferred constraint check).
			case *pgproto3.ErrorResponse:
				rr.err = ErrorResponseToPgError(msg)
			case *pgproto3.ReadyForQuery:
				rr.pgConn.contextWatcher.Unwatch()
				rr.pgConn.unlock()
				return rr.commandTag, rr.err
			}
		}
	}

	return rr.commandTag, rr.err
}
1604
1605
1606
// readUntilRowDescription consumes messages until the ResultReader's fieldDescriptions are loaded, a DataRow is next,
// or the command concludes. It returns no error; any error is recorded in the ResultReader by receiveMessage.
func (rr *ResultReader) readUntilRowDescription() {
	for !rr.commandConcluded {
		// Peek before receiving so that a DataRow is not consumed if the result set has no RowDescription (possible
		// when a query was built without a Describe). A peek error is ignored here; presumably the same failure then
		// surfaces in receiveMessage below, which records it.
		msg, _ := rr.pgConn.peekMessage()
		if _, ok := msg.(*pgproto3.DataRow); ok {
			return
		}

		// Consume the message; receiveMessage loads field descriptions or records errors as a side effect.
		msg, _ = rr.receiveMessage()
		if _, ok := msg.(*pgproto3.RowDescription); ok {
			return
		}
	}
}
1624
// receiveMessage receives the next message for this result and updates the reader's state from it. Messages go
// through the MultiResultReader when there is one.
//
// On a receive error the error is normalized and recorded, the context watch is released, and the reader is closed.
// Without a MultiResultReader, the connection is also closed asynchronously; a MultiResultReader's own
// receiveMessage already does that.
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
	if rr.multiResultReader == nil {
		msg, err = rr.pgConn.receiveMessage()
	} else {
		msg, err = rr.multiResultReader.receiveMessage()
	}

	if err != nil {
		err = normalizeTimeoutError(rr.ctx, err)
		rr.concludeCommand(CommandTag{}, err)
		rr.pgConn.contextWatcher.Unwatch()
		rr.closed = true
		if rr.multiResultReader == nil {
			rr.pgConn.asyncClose()
		}

		// Return rr.err rather than err: concludeCommand keeps an earlier error if one was already recorded.
		return nil, rr.err
	}

	switch msg := msg.(type) {
	case *pgproto3.RowDescription:
		rr.fieldDescriptions = rr.pgConn.convertRowDescription(rr.pgConn.fieldDescriptions[:], msg)
	case *pgproto3.CommandComplete:
		rr.concludeCommand(rr.pgConn.makeCommandTag(msg.CommandTag), nil)
	case *pgproto3.EmptyQueryResponse:
		rr.concludeCommand(CommandTag{}, nil)
	case *pgproto3.ErrorResponse:
		rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
	}

	return msg, nil
}
1657
1658 func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
1659
1660
1661 if err != nil && rr.err == nil {
1662 rr.err = err
1663 }
1664
1665 if rr.commandConcluded {
1666 return
1667 }
1668
1669 rr.commandTag = commandTag
1670 rr.rowValues = nil
1671 rr.commandConcluded = true
1672 }
1673
1674
// Batch is a collection of queries that ExecBatch sends to the PostgreSQL server in a single write.
type Batch struct {
	buf []byte // protocol messages encoded so far
	err error  // first encoding error; once set, further queued queries are ignored
}
1679
1680
1681 func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
1682 if batch.err != nil {
1683 return
1684 }
1685
1686 batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
1687 if batch.err != nil {
1688 return
1689 }
1690 batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
1691 }
1692
1693
1694 func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
1695 if batch.err != nil {
1696 return
1697 }
1698
1699 batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
1700 if batch.err != nil {
1701 return
1702 }
1703
1704 batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
1705 if batch.err != nil {
1706 return
1707 }
1708
1709 batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
1710 if batch.err != nil {
1711 return
1712 }
1713 }
1714
1715
1716
1717
1718 func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
1719 if batch.err != nil {
1720 return &MultiResultReader{
1721 closed: true,
1722 err: batch.err,
1723 }
1724 }
1725
1726 if err := pgConn.lock(); err != nil {
1727 return &MultiResultReader{
1728 closed: true,
1729 err: err,
1730 }
1731 }
1732
1733 pgConn.multiResultReader = MultiResultReader{
1734 pgConn: pgConn,
1735 ctx: ctx,
1736 }
1737 multiResult := &pgConn.multiResultReader
1738
1739 if ctx != context.Background() {
1740 select {
1741 case <-ctx.Done():
1742 multiResult.closed = true
1743 multiResult.err = newContextAlreadyDoneError(ctx)
1744 pgConn.unlock()
1745 return multiResult
1746 default:
1747 }
1748 pgConn.contextWatcher.Watch(ctx)
1749 }
1750
1751 batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
1752 if batch.err != nil {
1753 multiResult.closed = true
1754 multiResult.err = batch.err
1755 pgConn.unlock()
1756 return multiResult
1757 }
1758
1759 pgConn.enterPotentialWriteReadDeadlock()
1760 defer pgConn.exitPotentialWriteReadDeadlock()
1761 _, err := pgConn.conn.Write(batch.buf)
1762 if err != nil {
1763 multiResult.closed = true
1764 multiResult.err = err
1765 pgConn.unlock()
1766 return multiResult
1767 }
1768
1769 return multiResult
1770 }
1771
1772
1773
1774
1775
1776
1777 func (pgConn *PgConn) EscapeString(s string) (string, error) {
1778 if pgConn.ParameterStatus("standard_conforming_strings") != "on" {
1779 return "", errors.New("EscapeString must be run with standard_conforming_strings=on")
1780 }
1781
1782 if pgConn.ParameterStatus("client_encoding") != "UTF8" {
1783 return "", errors.New("EscapeString must be run with client_encoding=UTF8")
1784 }
1785
1786 return strings.Replace(s, "'", "''", -1), nil
1787 }
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797 func (pgConn *PgConn) CheckConn() error {
1798 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
1799 defer cancel()
1800
1801 _, err := pgConn.ReceiveMessage(ctx)
1802 if err != nil {
1803 if !Timeout(err) {
1804 return err
1805 }
1806 }
1807
1808 return nil
1809 }
1810
1811
1812
1813
1814 func (pgConn *PgConn) Ping(ctx context.Context) error {
1815 return pgConn.Exec(ctx, "-- ping").Close()
1816 }
1817
1818
1819 func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
1820 return CommandTag{s: string(buf)}
1821 }
1822
1823
1824
// enterPotentialWriteReadDeadlock must be called before a write that could block while the server is itself blocked
// writing to us. It arms slowWriteTimer. If the write has not finished when the timer fires, the timer callback starts
// the background reader (see Construct), which drains the server's output so the write can make progress. Every call
// must be paired with exitPotentialWriteReadDeadlock.
func (pgConn *PgConn) enterPotentialWriteReadDeadlock() {
	// The 15ms delay is somewhat arbitrary. A normal write should finish long before it; it must be short enough not
	// to hurt throughout when the write really is blocked. (Presumably it is also chosen not to be below coarse OS
	// timer resolutions.)
	//
	// Reset reports true when the timer was still active, i.e. a previous enter was not paired with an exit.
	if pgConn.slowWriteTimer.Reset(15 * time.Millisecond) {
		panic("BUG: slow write timer already active")
	}
}
1836
1837
// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. It disarms
// slowWriteTimer and, if the timer already fired, stops the background reader it started.
func (pgConn *PgConn) exitPotentialWriteReadDeadlock() {
	// Stop reports false when the timer has already fired, i.e. its callback ran or is running. (It also reports false
	// if the timer was never armed, in which case the receive below would block forever; hence the pairing rule.)
	if !pgConn.slowWriteTimer.Stop() {
		// The callback runs in its own goroutine. Wait until it signals that bgReader.Start has returned, so that the
		// Stop below cannot be overtaken by a late Start. A late Start would leave the background reader running into
		// a subsequent operation.
		<-pgConn.bgReaderStarted
		pgConn.bgReader.Stop()
	}
}
1850
1851 func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
1852 pgConn.enterPotentialWriteReadDeadlock()
1853 defer pgConn.exitPotentialWriteReadDeadlock()
1854 err := pgConn.frontend.Flush()
1855 return err
1856 }
1857
1858
1859
1860
1861
1862
1863
1864
1865 func (pgConn *PgConn) SyncConn(ctx context.Context) error {
1866 for i := 0; i < 10; i++ {
1867 if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 {
1868 return nil
1869 }
1870
1871 err := pgConn.Ping(ctx)
1872 if err != nil {
1873 return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", err)
1874 }
1875 }
1876
1877
1878
1879 return errors.New("SyncConn: conn never synchronized")
1880 }
1881
1882
1883
1884
1885
// HijackedConn holds the internal connection state extracted by Hijack. Construct turns it back into a PgConn.
type HijackedConn struct {
	Conn              net.Conn
	PID               uint32             // copied from the PgConn's pid (presumably the server backend PID)
	SecretKey         uint32             // copied from the PgConn's secretKey (presumably used for cancel requests)
	ParameterStatuses map[string]string  // server parameter statuses known to the PgConn; shared, not copied
	TxStatus          byte               // last known transaction status
	Frontend          *pgproto3.Frontend // note: Construct builds a new frontend from Config.BuildFrontend instead of reusing this
	Config            *Config
}
1895
1896
1897
1898
1899
1900
1901
1902 func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
1903 if err := pgConn.lock(); err != nil {
1904 return nil, err
1905 }
1906 pgConn.status = connStatusClosed
1907
1908 return &HijackedConn{
1909 Conn: pgConn.conn,
1910 PID: pgConn.pid,
1911 SecretKey: pgConn.secretKey,
1912 ParameterStatuses: pgConn.parameterStatuses,
1913 TxStatus: pgConn.txStatus,
1914 Frontend: pgConn.frontend,
1915 Config: pgConn.config,
1916 }, nil
1917 }
1918
1919
1920
1921
1922
1923
1924
1925
1926 func Construct(hc *HijackedConn) (*PgConn, error) {
1927 pgConn := &PgConn{
1928 conn: hc.Conn,
1929 pid: hc.PID,
1930 secretKey: hc.SecretKey,
1931 parameterStatuses: hc.ParameterStatuses,
1932 txStatus: hc.TxStatus,
1933 frontend: hc.Frontend,
1934 config: hc.Config,
1935
1936 status: connStatusIdle,
1937
1938 cleanupDone: make(chan struct{}),
1939 }
1940
1941 pgConn.contextWatcher = newContextWatcher(pgConn.conn)
1942 pgConn.bgReader = bgreader.New(pgConn.conn)
1943 pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
1944 func() {
1945 pgConn.bgReader.Start()
1946 pgConn.bgReaderStarted <- struct{}{}
1947 },
1948 )
1949 pgConn.slowWriteTimer.Stop()
1950 pgConn.bgReaderStarted = make(chan struct{})
1951 pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn)
1952
1953 return pgConn, nil
1954 }
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
// Pipeline represents a connection in pipeline mode.
//
// SendPrepare, SendDeallocate, SendQueryParams, and SendQueryPrepared queue requests in the frontend. They are not
// written until Flush or Sync is called. Sync must be called after the last queued request; Close otherwise closes
// the connection with an error. The context the pipeline was started with is watched until Close.
type Pipeline struct {
	conn *PgConn
	ctx  context.Context

	expectedReadyForQueryCount int  // ReadyForQuery messages still to be read, one per successful Sync
	pendingSync                bool // requests have been queued since the last Sync

	err    error
	closed bool
}
1977
1978
// PipelineSync is returned by GetResults when a ReadyForQuery message is received.
type PipelineSync struct{}
1980
1981
// CloseComplete is returned by GetResults when a CloseComplete message is received.
type CloseComplete struct{}
1983
1984
1985
1986
1987
1988
1989
1990 func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
1991 if err := pgConn.lock(); err != nil {
1992 return &Pipeline{
1993 closed: true,
1994 err: err,
1995 }
1996 }
1997
1998 pgConn.pipeline = Pipeline{
1999 conn: pgConn,
2000 ctx: ctx,
2001 }
2002 pipeline := &pgConn.pipeline
2003
2004 if ctx != context.Background() {
2005 select {
2006 case <-ctx.Done():
2007 pipeline.closed = true
2008 pipeline.err = newContextAlreadyDoneError(ctx)
2009 pgConn.unlock()
2010 return pipeline
2011 default:
2012 }
2013 pgConn.contextWatcher.Watch(ctx)
2014 }
2015
2016 return pipeline
2017 }
2018
2019
2020 func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
2021 if p.closed {
2022 return
2023 }
2024 p.pendingSync = true
2025
2026 p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
2027 p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
2028 }
2029
2030
2031 func (p *Pipeline) SendDeallocate(name string) {
2032 if p.closed {
2033 return
2034 }
2035 p.pendingSync = true
2036
2037 p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
2038 }
2039
2040
2041 func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
2042 if p.closed {
2043 return
2044 }
2045 p.pendingSync = true
2046
2047 p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
2048 p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
2049 p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
2050 p.conn.frontend.SendExecute(&pgproto3.Execute{})
2051 }
2052
2053
2054 func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
2055 if p.closed {
2056 return
2057 }
2058 p.pendingSync = true
2059
2060 p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
2061 p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
2062 p.conn.frontend.SendExecute(&pgproto3.Execute{})
2063 }
2064
2065
2066 func (p *Pipeline) Flush() error {
2067 if p.closed {
2068 if p.err != nil {
2069 return p.err
2070 }
2071 return errors.New("pipeline closed")
2072 }
2073
2074 err := p.conn.flushWithPotentialWriteReadDeadlock()
2075 if err != nil {
2076 err = normalizeTimeoutError(p.ctx, err)
2077
2078 p.conn.asyncClose()
2079
2080 p.conn.contextWatcher.Unwatch()
2081 p.conn.unlock()
2082 p.closed = true
2083 p.err = err
2084 return err
2085 }
2086
2087 return nil
2088 }
2089
2090
2091 func (p *Pipeline) Sync() error {
2092 if p.closed {
2093 if p.err != nil {
2094 return p.err
2095 }
2096 return errors.New("pipeline closed")
2097 }
2098
2099 p.conn.frontend.SendSync(&pgproto3.Sync{})
2100 err := p.Flush()
2101 if err != nil {
2102 return err
2103 }
2104
2105 p.pendingSync = false
2106 p.expectedReadyForQueryCount++
2107
2108 return nil
2109 }
2110
2111
2112
2113
2114 func (p *Pipeline) GetResults() (results any, err error) {
2115 if p.closed {
2116 if p.err != nil {
2117 return nil, p.err
2118 }
2119 return nil, errors.New("pipeline closed")
2120 }
2121
2122 if p.expectedReadyForQueryCount == 0 {
2123 return nil, nil
2124 }
2125
2126 return p.getResults()
2127 }
2128
2129 func (p *Pipeline) getResults() (results any, err error) {
2130 for {
2131 msg, err := p.conn.receiveMessage()
2132 if err != nil {
2133 p.closed = true
2134 p.err = err
2135 p.conn.asyncClose()
2136 return nil, normalizeTimeoutError(p.ctx, err)
2137 }
2138
2139 switch msg := msg.(type) {
2140 case *pgproto3.RowDescription:
2141 p.conn.resultReader = ResultReader{
2142 pgConn: p.conn,
2143 pipeline: p,
2144 ctx: p.ctx,
2145 fieldDescriptions: p.conn.convertRowDescription(p.conn.fieldDescriptions[:], msg),
2146 }
2147 return &p.conn.resultReader, nil
2148 case *pgproto3.CommandComplete:
2149 p.conn.resultReader = ResultReader{
2150 commandTag: p.conn.makeCommandTag(msg.CommandTag),
2151 commandConcluded: true,
2152 closed: true,
2153 }
2154 return &p.conn.resultReader, nil
2155 case *pgproto3.ParseComplete:
2156 peekedMsg, err := p.conn.peekMessage()
2157 if err != nil {
2158 p.conn.asyncClose()
2159 return nil, normalizeTimeoutError(p.ctx, err)
2160 }
2161 if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok {
2162 return p.getResultsPrepare()
2163 }
2164 case *pgproto3.CloseComplete:
2165 return &CloseComplete{}, nil
2166 case *pgproto3.ReadyForQuery:
2167 p.expectedReadyForQueryCount--
2168 return &PipelineSync{}, nil
2169 case *pgproto3.ErrorResponse:
2170 pgErr := ErrorResponseToPgError(msg)
2171 return nil, pgErr
2172 }
2173
2174 }
2175 }
2176
// getResultsPrepare reads the rest of a SendPrepare response after its ParseComplete. It reads the
// ParameterDescription, then a RowDescription or NoData, and returns the resulting StatementDescription.
//
// An ErrorResponse is returned as a *PgError without closing anything. Receive errors and unexpected CommandComplete
// or ReadyForQuery messages close the connection.
func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
	psd := &StatementDescription{}

	for {
		msg, err := p.conn.receiveMessage()
		if err != nil {
			p.conn.asyncClose()
			return nil, normalizeTimeoutError(p.ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ParameterDescription:
			// Copy the OIDs so psd does not share storage with the message.
			psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs))
			copy(psd.ParamOIDs, msg.ParameterOIDs)
		case *pgproto3.RowDescription:
			// nil destination: presumably makes convertRowDescription allocate, so psd does not share the connection's
			// reusable fieldDescriptions buffer.
			psd.Fields = p.conn.convertRowDescription(nil, msg)
			return psd, nil

		// NoData is sent instead of RowDescription when the statement returns no rows (e.g. an INSERT without
		// RETURNING).
		case *pgproto3.NoData:
			return psd, nil

		// For example, the prepared SQL failed to parse.
		case *pgproto3.ErrorResponse:
			pgErr := ErrorResponseToPgError(msg)
			return nil, pgErr
		// These should never arrive while handling a Describe; close the connection rather than risk a deadlock.
		case *pgproto3.CommandComplete:
			p.conn.asyncClose()
			return nil, errors.New("BUG: received CommandComplete while handling Describe")
		case *pgproto3.ReadyForQuery:
			p.conn.asyncClose()
			return nil, errors.New("BUG: received ReadyForQuery while handling Describe")
		}
	}
}
2213
2214
// Close closes the pipeline and returns the connection to normal mode. Calling it on an already-closed pipeline just
// returns the stored error.
//
// If requests were queued without a following Sync, the connection is closed and an error is returned. Otherwise all
// outstanding results up to the expected ReadyForQuery messages are read and discarded. A *PgError does not stop
// draining; any other error closes the connection. The last error seen is returned. In both cases the context watch
// and the connection lock are released.
func (p *Pipeline) Close() error {
	if p.closed {
		return p.err
	}

	p.closed = true

	if p.pendingSync {
		// Without a Sync there is no ReadyForQuery to read up to, so the connection cannot be brought back in sync.
		p.conn.asyncClose()
		p.err = errors.New("pipeline has unsynced requests")
		p.conn.contextWatcher.Unwatch()
		p.conn.unlock()

		return p.err
	}

	for p.expectedReadyForQueryCount > 0 {
		_, err := p.getResults()
		if err != nil {
			p.err = err
			var pgErr *PgError
			if !errors.As(err, &pgErr) {
				p.conn.asyncClose()
				break
			}
		}
	}

	p.conn.contextWatcher.Unwatch()
	p.conn.unlock()

	return p.err
}
2248
View as plain text