1 package pgconn
2
3 import (
4 "context"
5 "crypto/md5"
6 "crypto/tls"
7 "encoding/binary"
8 "encoding/hex"
9 "errors"
10 "fmt"
11 "io"
12 "math"
13 "net"
14 "strconv"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/jackc/pgconn/internal/ctxwatch"
20 "github.com/jackc/pgio"
21 "github.com/jackc/pgproto3/v2"
22 )
23
// Connection lifecycle states stored in PgConn.status.
//
// The numeric order matters: IsClosed reports true for every state that sorts
// below connStatusIdle.
const (
	// connStatusUninitialized is the zero value; lock refuses connections in this state.
	connStatusUninitialized = iota
	// connStatusConnecting is set by connect while the startup exchange is running.
	connStatusConnecting
	// connStatusClosed is set by Close, asyncClose and on a FATAL ErrorResponse.
	connStatusClosed
	// connStatusIdle means the connection can be locked for an operation.
	connStatusIdle
	// connStatusBusy means an operation currently holds the connection.
	connStatusBusy
)
31
// wbufLen is the initial capacity, in bytes, of the write buffer that connect
// allocates for each PgConn.
const wbufLen = 1024
33
34
35
// Notice is a notice reported by the server. It has the same underlying structure
// as PgError; receiveMessage builds one from each pgproto3.NoticeResponse and
// hands it to the configured OnNotice handler.
type Notice PgError
37
38
// Notification is the value passed to the OnNotification handler when a
// pgproto3.NotificationResponse is received.
type Notification struct {
	PID     uint32 // copied from NotificationResponse.PID
	Channel string // copied from NotificationResponse.Channel
	Payload string // copied from NotificationResponse.Payload
}
44
45
// DialFunc opens the network connection to the server. connect calls it with the
// network and address produced by NetworkAddress; CancelRequest calls it with the
// remote address of the existing connection.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
47
48
49
// LookupFunc resolves host into the list of addresses to try. Each returned entry
// may be a bare address or a "host:port" pair: expandWithIPs uses the port of a
// pair when net.SplitHostPort accepts the entry and otherwise keeps the
// configured port.
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
51
52
// BuildFrontendFunc creates the Frontend a PgConn receives backend messages from.
// connect calls it with the (possibly TLS-wrapped) connection as both r and w.
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
54
55
56
57
58
// NoticeHandler is the signature of the OnNotice callback. receiveMessage invokes
// it synchronously for every NoticeResponse it reads, i.e. in the middle of
// whatever operation is currently reading from the connection.
type NoticeHandler func(*PgConn, *Notice)
60
61
62
63
64
// NotificationHandler is the signature of the OnNotification callback.
// receiveMessage invokes it synchronously for every NotificationResponse it
// reads, i.e. in the middle of whatever operation is currently reading from the
// connection.
type NotificationHandler func(*PgConn, *Notification)
66
67
// Frontend is the source of backend messages for a PgConn.
type Frontend interface {
	// Receive returns the next message sent by the server, or an error.
	Receive() (pgproto3.BackendMessage, error)
}
71
72
// PgConn is a single low-level connection to a PostgreSQL server.
//
// NOTE(review): status and the other fields are read and written without
// synchronization by the methods in this file, so a PgConn appears not to be
// safe for concurrent use — confirm against the package documentation.
type PgConn struct {
	conn              net.Conn          // underlying connection; the TLS connection when TLS was negotiated
	pid               uint32            // ProcessID from BackendKeyData; sent in cancel requests
	secretKey         uint32            // SecretKey from BackendKeyData; sent in cancel requests
	parameterStatuses map[string]string // latest value of every ParameterStatus received
	txStatus          byte              // TxStatus of the most recent ReadyForQuery
	frontend          Frontend          // message reader built by config.BuildFrontend

	config *Config

	status byte // one of the connStatus* constants

	// State of a background read started by signalMessage. The mutex is held
	// while that read is in flight; peekMessage collects the result.
	bufferingReceive    bool
	bufferingReceiveMux sync.Mutex
	bufferingReceiveMsg pgproto3.BackendMessage
	bufferingReceiveErr error

	// Message returned by peekMessage that receiveMessage has not consumed yet.
	peekedMsg pgproto3.BackendMessage

	// Preallocated, reused per connection rather than per operation.
	wbuf              []byte // write buffer; operations append encoded messages to it
	resultReader      ResultReader
	multiResultReader MultiResultReader
	contextWatcher    *ctxwatch.ContextWatcher

	// Closed once the connection has been closed and cleanup has finished.
	cleanupDone chan struct{}
}
100
101
102
103 func Connect(ctx context.Context, connString string) (*PgConn, error) {
104 config, err := ParseConfig(connString)
105 if err != nil {
106 return nil, err
107 }
108
109 return ConnectConfig(ctx, config)
110 }
111
112
113
114
115 func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
116 config, err := ParseConfigWithOptions(connString, parseConfigOptions)
117 if err != nil {
118 return nil, err
119 }
120
121 return ConnectConfig(ctx, config)
122 }
123
124
125
126
127
128
129
130
// ConnectConfig establishes a connection to a PostgreSQL server using config.
//
// config must have its createdByParseConfig flag set (the panic message directs
// callers to ParseConfig); otherwise ConnectConfig panics. octx is used for host
// lookup and, when config.ConnectTimeout is zero, for every connection attempt.
//
// The primary host and config.Fallbacks are expanded to addresses with
// config.LookupFunc and tried in order until one attempt succeeds, a server
// error that another host would not cure is received, or the list is exhausted.
// If no attempt succeeded but at least one failed only with a NotPreferredError,
// the last such candidate is retried with that error ignored.
//
// After a successful connect, config.AfterConnect (if set) is called; if it
// fails the connection is closed and the error is returned in a connectError.
func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) {
	// Defaults are expected to have been filled in when the config was built, so a
	// hand-constructed Config is rejected outright.
	if !config.createdByParseConfig {
		panic("config must be created by ParseConfig")
	}

	// Treat the primary host like just another fallback so one loop handles both.
	fallbackConfigs := []*FallbackConfig{
		{
			Host:      config.Host,
			Port:      config.Port,
			TLSConfig: config.TLSConfig,
		},
	}
	fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
	ctx := octx
	fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
	if err != nil {
		return nil, &connectError{config: config, msg: "hostname resolving error", err: err}
	}

	if len(fallbackConfigs) == 0 {
		return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
	}

	foundBestServer := false
	var fallbackConfig *FallbackConfig
	for i, fc := range fallbackConfigs {
		// ConnectTimeout bounds the attempts made against one host.
		if config.ConnectTimeout != 0 {
			// Start a fresh timeout for the first candidate and whenever the host
			// differs from the previous candidate; consecutive candidates with the
			// same Host share one timeout.
			if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) {
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
				// Deferred to function return (not loop iteration) so ctx is still
				// live for the retry and the AfterConnect call below.
				defer cancel()
			}
		} else {
			ctx = octx
		}
		pgConn, err = connect(ctx, config, fc, false)
		if err == nil {
			foundBestServer = true
			break
		} else if pgerr, ok := err.(*PgError); ok {
			err = &connectError{config: config, msg: "server error", err: pgerr}
			const ERRCODE_INVALID_PASSWORD = "28P01"
			const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000"
			const ERRCODE_INVALID_CATALOG_NAME = "3D000"
			const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501"
			// These errors stop the search instead of moving on to the next
			// candidate. && binds tighter than ||, so the TLS condition applies only
			// to ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION.
			if pgerr.Code == ERRCODE_INVALID_PASSWORD ||
				pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil ||
				pgerr.Code == ERRCODE_INVALID_CATALOG_NAME ||
				pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
				break
			}
		} else if cerr, ok := err.(*connectError); ok {
			// Remember a reachable but not preferred server as a last resort.
			if _, ok := cerr.err.(*NotPreferredError); ok {
				fallbackConfig = fc
			}
		}
	}

	if !foundBestServer && fallbackConfig != nil {
		// Retry the remembered candidate, this time accepting NotPreferredError.
		pgConn, err = connect(ctx, config, fallbackConfig, true)
		if pgerr, ok := err.(*PgError); ok {
			err = &connectError{config: config, msg: "server error", err: pgerr}
		}
	}

	if err != nil {
		// Already a connectError: connect wraps everything except *PgError, and
		// both call sites above wrap *PgError.
		return nil, err
	}

	if config.AfterConnect != nil {
		err := config.AfterConnect(ctx, pgConn)
		if err != nil {
			pgConn.conn.Close()
			return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
		}
	}

	return pgConn, nil
}
215
216 func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
217 var configs []*FallbackConfig
218
219 for _, fb := range fallbacks {
220
221 if isAbsolutePath(fb.Host) {
222 configs = append(configs, &FallbackConfig{
223 Host: fb.Host,
224 Port: fb.Port,
225 TLSConfig: fb.TLSConfig,
226 })
227
228 continue
229 }
230
231 ips, err := lookupFn(ctx, fb.Host)
232 if err != nil {
233 return nil, err
234 }
235
236 for _, ip := range ips {
237 splitIP, splitPort, err := net.SplitHostPort(ip)
238 if err == nil {
239 port, err := strconv.ParseUint(splitPort, 10, 16)
240 if err != nil {
241 return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)
242 }
243 configs = append(configs, &FallbackConfig{
244 Host: splitIP,
245 Port: uint16(port),
246 TLSConfig: fb.TLSConfig,
247 })
248 } else {
249 configs = append(configs, &FallbackConfig{
250 Host: ip,
251 Port: fb.Port,
252 TLSConfig: fb.TLSConfig,
253 })
254 }
255 }
256 }
257
258 return configs, nil
259 }
260
261 func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
262 ignoreNotPreferredErr bool) (*PgConn, error) {
263 pgConn := new(PgConn)
264 pgConn.config = config
265 pgConn.wbuf = make([]byte, 0, wbufLen)
266 pgConn.cleanupDone = make(chan struct{})
267
268 var err error
269 network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
270 netConn, err := config.DialFunc(ctx, network, address)
271 if err != nil {
272 var netErr net.Error
273 if errors.As(err, &netErr) && netErr.Timeout() {
274 err = &errTimeout{err: err}
275 }
276 return nil, &connectError{config: config, msg: "dial error", err: err}
277 }
278
279 pgConn.conn = netConn
280 pgConn.contextWatcher = newContextWatcher(netConn)
281 pgConn.contextWatcher.Watch(ctx)
282
283 if fallbackConfig.TLSConfig != nil {
284 tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
285 pgConn.contextWatcher.Unwatch()
286 if err != nil {
287 netConn.Close()
288 return nil, &connectError{config: config, msg: "tls error", err: err}
289 }
290
291 pgConn.conn = tlsConn
292 pgConn.contextWatcher = newContextWatcher(tlsConn)
293 pgConn.contextWatcher.Watch(ctx)
294 }
295
296 defer pgConn.contextWatcher.Unwatch()
297
298 pgConn.parameterStatuses = make(map[string]string)
299 pgConn.status = connStatusConnecting
300 pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
301
302 startupMsg := pgproto3.StartupMessage{
303 ProtocolVersion: pgproto3.ProtocolVersionNumber,
304 Parameters: make(map[string]string),
305 }
306
307
308 for k, v := range config.RuntimeParams {
309 startupMsg.Parameters[k] = v
310 }
311
312 startupMsg.Parameters["user"] = config.User
313 if config.Database != "" {
314 startupMsg.Parameters["database"] = config.Database
315 }
316
317 buf, err := startupMsg.Encode(pgConn.wbuf)
318 if err != nil {
319 return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
320 }
321 if _, err := pgConn.conn.Write(buf); err != nil {
322 pgConn.conn.Close()
323 return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
324 }
325
326 for {
327 msg, err := pgConn.receiveMessage()
328 if err != nil {
329 pgConn.conn.Close()
330 if err, ok := err.(*PgError); ok {
331 return nil, err
332 }
333 return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)}
334 }
335
336 switch msg := msg.(type) {
337 case *pgproto3.BackendKeyData:
338 pgConn.pid = msg.ProcessID
339 pgConn.secretKey = msg.SecretKey
340
341 case *pgproto3.AuthenticationOk:
342 case *pgproto3.AuthenticationCleartextPassword:
343 err = pgConn.txPasswordMessage(pgConn.config.Password)
344 if err != nil {
345 pgConn.conn.Close()
346 return nil, &connectError{config: config, msg: "failed to write password message", err: err}
347 }
348 case *pgproto3.AuthenticationMD5Password:
349 digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
350 err = pgConn.txPasswordMessage(digestedPassword)
351 if err != nil {
352 pgConn.conn.Close()
353 return nil, &connectError{config: config, msg: "failed to write password message", err: err}
354 }
355 case *pgproto3.AuthenticationSASL:
356 err = pgConn.scramAuth(msg.AuthMechanisms)
357 if err != nil {
358 pgConn.conn.Close()
359 return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
360 }
361 case *pgproto3.AuthenticationGSS:
362 err = pgConn.gssAuth()
363 if err != nil {
364 pgConn.conn.Close()
365 return nil, &connectError{config: config, msg: "failed GSS auth", err: err}
366 }
367 case *pgproto3.ReadyForQuery:
368 pgConn.status = connStatusIdle
369 if config.ValidateConnect != nil {
370
371
372
373
374
375 pgConn.contextWatcher.Unwatch()
376
377 err := config.ValidateConnect(ctx, pgConn)
378 if err != nil {
379 if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok {
380 return pgConn, nil
381 }
382 pgConn.conn.Close()
383 return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
384 }
385 }
386 return pgConn, nil
387 case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse:
388
389 case *pgproto3.ErrorResponse:
390 pgConn.conn.Close()
391 return nil, ErrorResponseToPgError(msg)
392 default:
393 pgConn.conn.Close()
394 return nil, &connectError{config: config, msg: "received unexpected message", err: err}
395 }
396 }
397 }
398
399 func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
400 return ctxwatch.NewContextWatcher(
401 func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
402 func() { conn.SetDeadline(time.Time{}) },
403 )
404 }
405
// startTLS asks the server to switch conn to TLS and, if the server agrees,
// returns conn wrapped in a TLS client using tlsConfig.
//
// It writes the 8-byte request (big-endian int32 length 8 followed by the
// request code 80877103) and reads a single response byte. 'S' means the server
// accepted; any other byte yields the error "server refused TLS connection".
// Write and read errors are returned unchanged. The TLS handshake itself is not
// performed here.
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
	var request [8]byte
	binary.BigEndian.PutUint32(request[0:4], 8)
	binary.BigEndian.PutUint32(request[4:8], 80877103)
	if _, err := conn.Write(request[:]); err != nil {
		return nil, err
	}

	var reply [1]byte
	if _, err := io.ReadFull(conn, reply[:]); err != nil {
		return nil, err
	}

	if reply[0] == 'S' {
		return tls.Client(conn, tlsConfig), nil
	}
	return nil, errors.New("server refused TLS connection")
}
423
424 func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
425 msg := &pgproto3.PasswordMessage{Password: password}
426 buf, err := msg.Encode(pgConn.wbuf)
427 if err != nil {
428 return err
429 }
430 _, err = pgConn.conn.Write(buf)
431 return err
432 }
433
// hexMD5 returns the MD5 digest of s as a lowercase hexadecimal string.
func hexMD5(s string) string {
	hasher := md5.New()
	// hash.Hash documents that Write never returns an error.
	_, _ = io.WriteString(hasher, s)
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest)
}
439
// signalMessage starts reading the next backend message in a background
// goroutine and returns a channel that is closed once that read has finished.
// The result is stored in bufferingReceiveMsg / bufferingReceiveErr and is
// picked up by the next peekMessage call.
//
// It panics if a previous background read has not been collected yet.
func (pgConn *PgConn) signalMessage() chan struct{} {
	if pgConn.bufferingReceive {
		panic("BUG: signalMessage when already in progress")
	}

	pgConn.bufferingReceive = true
	// Locked here and released by the goroutine after Receive returns, so
	// peekMessage blocks on the mutex until the buffered result is available.
	pgConn.bufferingReceiveMux.Lock()

	ch := make(chan struct{})
	go func() {
		pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
		pgConn.bufferingReceiveMux.Unlock()
		close(ch)
	}()

	return ch
}
457
458
459
460
461
462
// SendBytes writes buf to the connection unchanged. The connection is locked for
// the duration of the call; a busy, closed or uninitialized connection yields
// the lock error.
//
// Unless ctx is context.Background(), an already-done ctx is rejected before
// anything is written and ctx is watched while the write is in progress.
//
// A write error closes the connection asynchronously and is returned as a
// *writeError whose safeToRetry is true only when no bytes were written.
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
	if err := pgConn.lock(); err != nil {
		return err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	n, err := pgConn.conn.Write(buf)
	if err != nil {
		pgConn.asyncClose()
		return &writeError{err: err, safeToRetry: n == 0}
	}

	return nil
}
487
488
489
490
491
492
493
494
// ReceiveMessage reads and returns the next backend message. The connection is
// locked for the duration of the call; a busy, closed or uninitialized
// connection yields the lock error.
//
// Unless ctx is context.Background(), an already-done ctx is rejected before
// reading and ctx is watched while the read is in progress.
//
// A receive failure is returned as a *pgconnError with the message "receive
// message failed" and safeToRetry set to true.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
	if err := pgConn.lock(); err != nil {
		return nil, err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	msg, err := pgConn.receiveMessage()
	if err != nil {
		err = &pgconnError{
			msg:         "receive message failed",
			err:         preferContextOverNetTimeoutError(ctx, err),
			safeToRetry: true}
	}
	return msg, err
}
520
521
// peekMessage returns the next backend message without consuming it: the message
// is cached in peekedMsg and returned again by later calls until receiveMessage
// clears it.
//
// If signalMessage started a background read, its result is used; otherwise the
// message is read directly from the frontend. Any error other than a network
// timeout closes the connection asynchronously.
func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
	if pgConn.peekedMsg != nil {
		return pgConn.peekedMsg, nil
	}

	var msg pgproto3.BackendMessage
	var err error
	if pgConn.bufferingReceive {
		// The mutex is held by signalMessage until its Receive call returns, so
		// taking it here waits for the background read to finish.
		pgConn.bufferingReceiveMux.Lock()
		msg = pgConn.bufferingReceiveMsg
		err = pgConn.bufferingReceiveErr
		pgConn.bufferingReceiveMux.Unlock()
		pgConn.bufferingReceive = false

		// The background read timed out: try the read again in the foreground.
		var netErr net.Error
		if errors.As(err, &netErr) && netErr.Timeout() {
			msg, err = pgConn.frontend.Receive()
		}
	} else {
		msg, err = pgConn.frontend.Receive()
	}

	if err != nil {
		// Everything except a timeout is treated as fatal for the connection.
		var netErr net.Error
		isNetErr := errors.As(err, &netErr)
		if !(isNetErr && netErr.Timeout()) {
			pgConn.asyncClose()
		}

		return nil, err
	}

	pgConn.peekedMsg = msg
	return msg, nil
}
559
560
// receiveMessage consumes and returns the next backend message, updating
// connection state along the way:
//
//   - ReadyForQuery updates txStatus.
//   - ParameterStatus updates parameterStatuses.
//   - An ErrorResponse with severity FATAL marks the connection closed, closes
//     the network connection and cleanupDone, and is returned as a *PgError.
//   - NoticeResponse and NotificationResponse invoke the configured OnNotice and
//     OnNotification handlers.
//
// Any read error other than a network timeout closes the connection
// asynchronously.
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
	msg, err := pgConn.peekMessage()
	if err != nil {
		// Everything except a timeout is treated as fatal for the connection.
		var netErr net.Error
		isNetErr := errors.As(err, &netErr)
		if !(isNetErr && netErr.Timeout()) {
			pgConn.asyncClose()
		}

		return nil, err
	}
	// The message is consumed now; the next peek must read a new one.
	pgConn.peekedMsg = nil

	switch msg := msg.(type) {
	case *pgproto3.ReadyForQuery:
		pgConn.txStatus = msg.TxStatus
	case *pgproto3.ParameterStatus:
		pgConn.parameterStatuses[msg.Name] = msg.Value
	case *pgproto3.ErrorResponse:
		if msg.Severity == "FATAL" {
			pgConn.status = connStatusClosed
			pgConn.conn.Close()
			close(pgConn.cleanupDone)
			return nil, ErrorResponseToPgError(msg)
		}
	case *pgproto3.NoticeResponse:
		if pgConn.config.OnNotice != nil {
			pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg))
		}
	case *pgproto3.NotificationResponse:
		if pgConn.config.OnNotification != nil {
			pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
		}
	}

	return msg, nil
}
599
600
// Conn returns the underlying net.Conn (the TLS connection when TLS was
// negotiated during connect).
func (pgConn *PgConn) Conn() net.Conn {
	return pgConn.conn
}
604
605
// PID returns the process ID the server reported in BackendKeyData during
// connect.
func (pgConn *PgConn) PID() uint32 {
	return pgConn.pid
}
609
610
611
612
613
614
615
616
617
618
// TxStatus returns the transaction status byte carried by the most recently
// received ReadyForQuery message. The meaning of the individual byte values is
// defined by the PostgreSQL protocol, not by this file.
func (pgConn *PgConn) TxStatus() byte {
	return pgConn.txStatus
}
622
623
// SecretKey returns the secret key the server reported in BackendKeyData during
// connect; CancelRequest sends it together with the PID.
func (pgConn *PgConn) SecretKey() uint32 {
	return pgConn.secretKey
}
627
628
629
630
// Close closes the connection. It marks the connection closed, makes a
// best-effort attempt to send the terminate message ('X') and closes the network
// connection, returning the result of that close. cleanupDone is closed when
// Close returns. Calling Close on an already closed connection is a no-op that
// returns nil.
//
// Unless ctx is context.Background(), ctx is watched while the terminate message
// is written.
func (pgConn *PgConn) Close(ctx context.Context) error {
	if pgConn.status == connStatusClosed {
		return nil
	}
	pgConn.status = connStatusClosed

	defer close(pgConn.cleanupDone)
	defer pgConn.conn.Close()

	if ctx != context.Background() {
		// Close does not take the connection lock, so an operation may still be
		// watching its own context. Release that watch before starting a new one.
		// NOTE(review): relies on Unwatch being harmless when nothing is watched —
		// confirm against ctxwatch.
		pgConn.contextWatcher.Unwatch()

		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Terminate message: type 'X' followed by the int32 length 4. The write error
	// is deliberately ignored; the connection is closed regardless.
	pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})

	return pgConn.conn.Close()
}
661
662
663
// asyncClose marks the connection closed immediately and performs the actual
// shutdown in a background goroutine: it sends a cancel request, then a
// best-effort terminate message, and finally closes the network connection and
// cleanupDone. The whole sequence is bounded by a 15 second deadline. It does
// nothing if the connection is already closed.
func (pgConn *PgConn) asyncClose() {
	if pgConn.status == connStatusClosed {
		return
	}
	pgConn.status = connStatusClosed

	go func() {
		defer close(pgConn.cleanupDone)
		defer pgConn.conn.Close()

		deadline := time.Now().Add(time.Second * 15)

		ctx, cancel := context.WithDeadline(context.Background(), deadline)
		defer cancel()

		// Errors are ignored: this is best-effort cleanup.
		pgConn.CancelRequest(ctx)

		pgConn.conn.SetDeadline(deadline)

		// Terminate message: type 'X' followed by the int32 length 4.
		pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
	}()
}
686
687
688
689
690
691
692
693
694
// CleanupDone returns a channel that is closed once the connection has been
// closed and its cleanup has finished: at the end of Close, at the end of the
// background goroutine started by asyncClose, or when a FATAL error response is
// received.
func (pgConn *PgConn) CleanupDone() chan (struct{}) {
	return pgConn.cleanupDone
}
698
699
700
701
// IsClosed reports whether the connection is unusable. Every status that sorts
// below connStatusIdle counts as closed: uninitialized, connecting and closed.
func (pgConn *PgConn) IsClosed() bool {
	return pgConn.status < connStatusIdle
}
705
706
// IsBusy reports whether an operation currently holds the connection lock.
func (pgConn *PgConn) IsBusy() bool {
	return pgConn.status == connStatusBusy
}
710
711
712 func (pgConn *PgConn) lock() error {
713 switch pgConn.status {
714 case connStatusBusy:
715 return &connLockError{status: "conn busy"}
716 case connStatusClosed:
717 return &connLockError{status: "conn closed"}
718 case connStatusUninitialized:
719 return &connLockError{status: "conn uninitialized"}
720 }
721 pgConn.status = connStatusBusy
722 return nil
723 }
724
725 func (pgConn *PgConn) unlock() {
726 switch pgConn.status {
727 case connStatusBusy:
728 pgConn.status = connStatusIdle
729 case connStatusClosed:
730 default:
731 panic("BUG: cannot unlock unlocked connection")
732 }
733 }
734
735
736
// ParameterStatus returns the most recent value the server reported for the
// parameter named key in a ParameterStatus message, or the empty string if no
// value has been reported for it.
func (pgConn *PgConn) ParameterStatus(key string) string {
	return pgConn.parameterStatuses[key]
}
740
741
// CommandTag is the raw command tag taken from a CommandComplete message, for
// example "INSERT 0 5" or "SELECT 3".
type CommandTag []byte

// RowsAffected returns the number formed by the trailing run of ASCII digits in
// the tag. It returns 0 when the tag is empty or does not end in a digit.
func (ct CommandTag) RowsAffected() int64 {
	// Walk backwards to the first byte of the trailing digit run.
	start := len(ct)
	for start > 0 && ct[start-1] >= '0' && ct[start-1] <= '9' {
		start--
	}
	if start == len(ct) {
		return 0
	}

	var rows int64
	for _, digit := range ct[start:] {
		rows = rows*10 + int64(digit-'0')
	}
	return rows
}

// String returns the tag as a string.
func (ct CommandTag) String() string {
	return string(ct)
}

// Insert reports whether the tag starts with "INSERT".
func (ct CommandTag) Insert() bool {
	return len(ct) >= 6 && string(ct[:6]) == "INSERT"
}

// Update reports whether the tag starts with "UPDATE".
func (ct CommandTag) Update() bool {
	return len(ct) >= 6 && string(ct[:6]) == "UPDATE"
}

// Delete reports whether the tag starts with "DELETE".
func (ct CommandTag) Delete() bool {
	return len(ct) >= 6 && string(ct[:6]) == "DELETE"
}

// Select reports whether the tag starts with "SELECT".
func (ct CommandTag) Select() bool {
	return len(ct) >= 6 && string(ct[:6]) == "SELECT"
}
816
// StatementDescription describes a prepared statement as returned by Prepare.
type StatementDescription struct {
	Name      string                      // statement name passed to Prepare
	SQL       string                      // SQL text passed to Prepare
	ParamOIDs []uint32                    // copied from the server's ParameterDescription
	Fields    []pgproto3.FieldDescription // copied from the server's RowDescription; nil if none was sent
}
823
824
825
// Prepare creates a prepared statement called name from sql and returns its
// description. paramOIDs is sent with the Parse message as the parameter types.
//
// Parse, Describe (statement) and Sync are sent in a single write; the reply is
// read up to ReadyForQuery. An ErrorResponse from the server is returned as a
// *PgError after the exchange has completed, leaving the connection usable.
//
// The connection is locked for the duration of the call. Unless ctx is
// context.Background(), an already-done ctx is rejected up front and ctx is
// watched during the exchange. A write or receive failure closes the connection
// asynchronously.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
	if err := pgConn.lock(); err != nil {
		return nil, err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	buf := pgConn.wbuf
	var err error
	buf, err = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
	if err != nil {
		return nil, err
	}
	buf, err = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
	if err != nil {
		return nil, err
	}
	buf, err = (&pgproto3.Sync{}).Encode(buf)
	if err != nil {
		return nil, err
	}

	n, err := pgConn.conn.Write(buf)
	if err != nil {
		pgConn.asyncClose()
		// Only a write that sent nothing is safe to retry.
		return nil, &writeError{err: err, safeToRetry: n == 0}
	}

	psd := &StatementDescription{Name: name, SQL: sql}

	var parseErr error

readloop:
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return nil, preferContextOverNetTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ParameterDescription:
			// Copied rather than aliased into the message.
			psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs))
			copy(psd.ParamOIDs, msg.ParameterOIDs)
		case *pgproto3.RowDescription:
			psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields))
			copy(psd.Fields, msg.Fields)
		case *pgproto3.ErrorResponse:
			// Keep reading until ReadyForQuery so the connection stays in sync.
			parseErr = ErrorResponseToPgError(msg)
		case *pgproto3.ReadyForQuery:
			break readloop
		}
	}

	if parseErr != nil {
		return nil, parseErr
	}
	return psd, nil
}
894
895
// ErrorResponseToPgError converts a wire-level pgproto3.ErrorResponse into a
// *PgError by copying each field to the PgError field of the same name.
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
	return &PgError{
		Severity:         msg.Severity,
		Code:             string(msg.Code),
		Message:          string(msg.Message),
		Detail:           string(msg.Detail),
		Hint:             msg.Hint,
		Position:         msg.Position,
		InternalPosition: msg.InternalPosition,
		InternalQuery:    string(msg.InternalQuery),
		Where:            string(msg.Where),
		SchemaName:       string(msg.SchemaName),
		TableName:        string(msg.TableName),
		ColumnName:       string(msg.ColumnName),
		DataTypeName:     string(msg.DataTypeName),
		ConstraintName:   msg.ConstraintName,
		File:             string(msg.File),
		Line:             msg.Line,
		Routine:          string(msg.Routine),
	}
}
917
918 func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
919 pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg))
920 return (*Notice)(pgerr)
921 }
922
923
924
925
// CancelRequest sends a cancel request for this connection's backend process.
//
// The request goes over a new, separate connection dialed with config.DialFunc
// to the remote address of the existing connection. After writing the request it
// reads from that connection: io.EOF (the server closing the connection) counts
// as success, any other read result is returned as is. The method does not take
// the connection lock.
//
// Unless ctx is context.Background(), ctx is watched for the duration of the
// exchange on the cancel connection.
func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
	// Reuse the resolved remote address of the live connection rather than the
	// configured host so the request reaches the same server.
	serverAddr := pgConn.conn.RemoteAddr()
	cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
	if err != nil {
		return err
	}
	defer cancelConn.Close()

	if ctx != context.Background() {
		contextWatcher := ctxwatch.NewContextWatcher(
			func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
			func() { cancelConn.SetDeadline(time.Time{}) },
		)
		contextWatcher.Watch(ctx)
		defer contextWatcher.Unwatch()
	}

	// 16-byte request: length (16), cancel request code (80877102), backend
	// process ID and secret key, each as a big-endian uint32.
	buf := make([]byte, 16)
	binary.BigEndian.PutUint32(buf[0:4], 16)
	binary.BigEndian.PutUint32(buf[4:8], 80877102)
	binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid))
	binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
	_, err = cancelConn.Write(buf)
	if err != nil {
		return err
	}

	// No reply is expected; wait for the server to close the connection.
	_, err = cancelConn.Read(buf)
	if err != io.EOF {
		return err
	}

	return nil
}
963
964
965
966 func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
967 if err := pgConn.lock(); err != nil {
968 return err
969 }
970 defer pgConn.unlock()
971
972 if ctx != context.Background() {
973 select {
974 case <-ctx.Done():
975 return newContextAlreadyDoneError(ctx)
976 default:
977 }
978
979 pgConn.contextWatcher.Watch(ctx)
980 defer pgConn.contextWatcher.Unwatch()
981 }
982
983 for {
984 msg, err := pgConn.receiveMessage()
985 if err != nil {
986 return preferContextOverNetTimeoutError(ctx, err)
987 }
988
989 switch msg.(type) {
990 case *pgproto3.NotificationResponse:
991 return nil
992 }
993 }
994 }
995
996
997
998
999
1000
1001 func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
1002 if err := pgConn.lock(); err != nil {
1003 return &MultiResultReader{
1004 closed: true,
1005 err: err,
1006 }
1007 }
1008
1009 pgConn.multiResultReader = MultiResultReader{
1010 pgConn: pgConn,
1011 ctx: ctx,
1012 }
1013 multiResult := &pgConn.multiResultReader
1014 if ctx != context.Background() {
1015 select {
1016 case <-ctx.Done():
1017 multiResult.closed = true
1018 multiResult.err = newContextAlreadyDoneError(ctx)
1019 pgConn.unlock()
1020 return multiResult
1021 default:
1022 }
1023 pgConn.contextWatcher.Watch(ctx)
1024 }
1025
1026 buf := pgConn.wbuf
1027 var err error
1028 buf, err = (&pgproto3.Query{String: sql}).Encode(buf)
1029 if err != nil {
1030 return &MultiResultReader{
1031 closed: true,
1032 err: err,
1033 }
1034 }
1035
1036 n, err := pgConn.conn.Write(buf)
1037 if err != nil {
1038 pgConn.asyncClose()
1039 pgConn.contextWatcher.Unwatch()
1040 multiResult.closed = true
1041 multiResult.err = &writeError{err: err, safeToRetry: n == 0}
1042 pgConn.unlock()
1043 return multiResult
1044 }
1045
1046 return multiResult
1047 }
1048
1049
1050
1051
1052
1053
// ReceiveResults returns a MultiResultReader for results the server is expected
// to send without sending anything itself — presumably for use after the
// request was written by other means such as SendBytes (confirm with callers).
//
// If the connection cannot be locked, or ctx is already done (checked unless ctx
// is context.Background()), the returned reader is already closed with err set.
// Otherwise the connection stays locked and ctx stays watched when this method
// returns; presumably the reader releases both (its methods are outside this
// view).
func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader {
	if err := pgConn.lock(); err != nil {
		return &MultiResultReader{
			closed: true,
			err:    err,
		}
	}

	// The reader is stored in the PgConn and reused across calls.
	pgConn.multiResultReader = MultiResultReader{
		pgConn: pgConn,
		ctx:    ctx,
	}
	multiResult := &pgConn.multiResultReader
	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			multiResult.closed = true
			multiResult.err = newContextAlreadyDoneError(ctx)
			pgConn.unlock()
			return multiResult
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
	}

	return multiResult
}
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
// ExecParams executes sql through the extended protocol using the unnamed
// prepared statement and returns a ResultReader for its result.
//
// sql and paramOIDs go into the Parse message; paramValues, paramFormats and
// resultFormats go into the Bind message. execExtendedPrefix rejects more than
// math.MaxUint16 parameter values.
//
// Failures are reported through the returned reader, which is then closed. On
// an encoding failure nothing has been sent: ctx is unwatched and the connection
// unlocked.
func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader {
	result := pgConn.execExtendedPrefix(ctx, paramValues)
	if result.closed {
		return result
	}

	buf := pgConn.wbuf
	var err error
	buf, err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
	if err != nil {
		result.concludeCommand(nil, err)
		pgConn.contextWatcher.Unwatch()
		result.closed = true
		pgConn.unlock()
		return result
	}

	buf, err = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
	if err != nil {
		result.concludeCommand(nil, err)
		pgConn.contextWatcher.Unwatch()
		result.closed = true
		pgConn.unlock()
		return result
	}

	// Appends Describe/Execute/Sync, sends everything and starts reading.
	pgConn.execExtendedSuffix(buf, result)

	return result
}
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
// ExecPrepared executes the previously prepared statement stmtName through the
// extended protocol and returns a ResultReader for its result.
//
// paramValues, paramFormats and resultFormats go into the Bind message.
// execExtendedPrefix rejects more than math.MaxUint16 parameter values.
//
// Failures are reported through the returned reader, which is then closed. On
// an encoding failure nothing has been sent: ctx is unwatched and the connection
// unlocked.
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader {
	result := pgConn.execExtendedPrefix(ctx, paramValues)
	if result.closed {
		return result
	}

	buf := pgConn.wbuf
	var err error
	buf, err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
	if err != nil {
		result.concludeCommand(nil, err)
		pgConn.contextWatcher.Unwatch()
		result.closed = true
		pgConn.unlock()
		return result
	}

	// Appends Describe/Execute/Sync, sends everything and starts reading.
	pgConn.execExtendedSuffix(buf, result)

	return result
}
1165
// execExtendedPrefix performs the checks shared by ExecParams and ExecPrepared
// and returns the connection's reusable ResultReader, reset for ctx.
//
// The returned reader is already closed (with the error recorded through
// concludeCommand) when the connection cannot be locked, when paramValues has
// more than math.MaxUint16 entries, or when ctx is already done (checked unless
// ctx is context.Background()). Otherwise the connection is left locked and ctx
// watched for the caller to continue.
//
// NOTE(review): pgConn.resultReader is overwritten before lock() is checked, so
// calling this while another ResultReader from the same connection is still in
// use resets that reader's state — confirm whether this is intended.
func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
	pgConn.resultReader = ResultReader{
		pgConn: pgConn,
		ctx:    ctx,
	}
	result := &pgConn.resultReader

	if err := pgConn.lock(); err != nil {
		result.concludeCommand(nil, err)
		result.closed = true
		return result
	}

	if len(paramValues) > math.MaxUint16 {
		result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
		result.closed = true
		pgConn.unlock()
		return result
	}

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			result.concludeCommand(nil, newContextAlreadyDoneError(ctx))
			result.closed = true
			pgConn.unlock()
			return result
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
	}

	return result
}
1200
1201 func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
1202 var err error
1203 buf, err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
1204 if err != nil {
1205 result.concludeCommand(nil, err)
1206 pgConn.contextWatcher.Unwatch()
1207 result.closed = true
1208 pgConn.unlock()
1209 return
1210 }
1211 buf, err = (&pgproto3.Execute{}).Encode(buf)
1212 if err != nil {
1213 result.concludeCommand(nil, err)
1214 pgConn.contextWatcher.Unwatch()
1215 result.closed = true
1216 pgConn.unlock()
1217 return
1218 }
1219 buf, err = (&pgproto3.Sync{}).Encode(buf)
1220 if err != nil {
1221 result.concludeCommand(nil, err)
1222 pgConn.contextWatcher.Unwatch()
1223 result.closed = true
1224 pgConn.unlock()
1225 return
1226 }
1227
1228 n, err := pgConn.conn.Write(buf)
1229 if err != nil {
1230 pgConn.asyncClose()
1231 result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0})
1232 pgConn.contextWatcher.Unwatch()
1233 result.closed = true
1234 pgConn.unlock()
1235 return
1236 }
1237
1238 result.readUntilRowDescription()
1239 }
1240
1241
// CopyTo sends sql as a simple Query and writes the payload of every CopyData
// message the server returns to w. It returns the command tag from
// CommandComplete once ReadyForQuery arrives; an ErrorResponse from the server
// is returned as a *PgError alongside whatever command tag was received.
//
// The connection is locked for the duration of the call. Unless ctx is
// context.Background(), an already-done ctx is rejected up front and ctx is
// watched during the exchange. A write error, a receive error, or an error from
// w closes the connection asynchronously.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
	if err := pgConn.lock(); err != nil {
		return nil, err
	}

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			pgConn.unlock()
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Send the copy command.
	buf := pgConn.wbuf
	var err error
	buf, err = (&pgproto3.Query{String: sql}).Encode(buf)
	if err != nil {
		pgConn.unlock()
		return nil, err
	}

	n, err := pgConn.conn.Write(buf)
	if err != nil {
		pgConn.asyncClose()
		pgConn.unlock()
		return nil, &writeError{err: err, safeToRetry: n == 0}
	}

	// Read the results. The error paths below do not call unlock: asyncClose
	// has already moved the connection to the closed state.
	var commandTag CommandTag
	var pgErr error
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return nil, preferContextOverNetTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.CopyDone:
		case *pgproto3.CopyData:
			_, err := w.Write(msg.Data)
			if err != nil {
				pgConn.asyncClose()
				return nil, err
			}
		case *pgproto3.ReadyForQuery:
			pgConn.unlock()
			return commandTag, pgErr
		case *pgproto3.CommandComplete:
			commandTag = CommandTag(msg.CommandTag)
		case *pgproto3.ErrorResponse:
			// Remember the error and keep reading until ReadyForQuery.
			pgErr = ErrorResponseToPgError(msg)
		}
	}
}
1302
1303
1304
1305
1306
1307 func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
1308 if err := pgConn.lock(); err != nil {
1309 return nil, err
1310 }
1311 defer pgConn.unlock()
1312
1313 if ctx != context.Background() {
1314 select {
1315 case <-ctx.Done():
1316 return nil, newContextAlreadyDoneError(ctx)
1317 default:
1318 }
1319 pgConn.contextWatcher.Watch(ctx)
1320 defer pgConn.contextWatcher.Unwatch()
1321 }
1322
1323
1324 buf := pgConn.wbuf
1325 var err error
1326 buf, err = (&pgproto3.Query{String: sql}).Encode(buf)
1327 if err != nil {
1328 pgConn.unlock()
1329 return nil, err
1330 }
1331
1332 n, err := pgConn.conn.Write(buf)
1333 if err != nil {
1334 pgConn.asyncClose()
1335 return nil, &writeError{err: err, safeToRetry: n == 0}
1336 }
1337
1338
1339 abortCopyChan := make(chan struct{})
1340 copyErrChan := make(chan error, 1)
1341 signalMessageChan := pgConn.signalMessage()
1342 var wg sync.WaitGroup
1343 wg.Add(1)
1344
1345 go func() {
1346 defer wg.Done()
1347 buf := make([]byte, 0, 65536)
1348 buf = append(buf, 'd')
1349 sp := len(buf)
1350
1351 for {
1352 n, readErr := r.Read(buf[5:cap(buf)])
1353 if n > 0 {
1354 buf = buf[0 : n+5]
1355 pgio.SetInt32(buf[sp:], int32(n+4))
1356
1357 _, writeErr := pgConn.conn.Write(buf)
1358 if writeErr != nil {
1359
1360 pgConn.conn.Close()
1361
1362 copyErrChan <- writeErr
1363 return
1364 }
1365 }
1366 if readErr != nil {
1367 copyErrChan <- readErr
1368 return
1369 }
1370
1371 select {
1372 case <-abortCopyChan:
1373 return
1374 default:
1375 }
1376 }
1377 }()
1378
1379 var pgErr error
1380 var copyErr error
1381 for copyErr == nil && pgErr == nil {
1382 select {
1383 case copyErr = <-copyErrChan:
1384 case <-signalMessageChan:
1385 msg, err := pgConn.receiveMessage()
1386 if err != nil {
1387 pgConn.asyncClose()
1388 return nil, preferContextOverNetTimeoutError(ctx, err)
1389 }
1390
1391 switch msg := msg.(type) {
1392 case *pgproto3.ErrorResponse:
1393 pgErr = ErrorResponseToPgError(msg)
1394 default:
1395 signalMessageChan = pgConn.signalMessage()
1396 }
1397 }
1398 }
1399 close(abortCopyChan)
1400
1401 wg.Wait()
1402
1403 buf = buf[:0]
1404 if copyErr == io.EOF || pgErr != nil {
1405 copyDone := &pgproto3.CopyDone{}
1406 var err error
1407 buf, err = copyDone.Encode(buf)
1408 if err != nil {
1409 pgConn.asyncClose()
1410 return nil, err
1411 }
1412 } else {
1413 copyFail := &pgproto3.CopyFail{Message: copyErr.Error()}
1414 var err error
1415 buf, err = copyFail.Encode(buf)
1416 if err != nil {
1417 pgConn.asyncClose()
1418 return nil, err
1419 }
1420 }
1421 _, err = pgConn.conn.Write(buf)
1422 if err != nil {
1423 pgConn.asyncClose()
1424 return nil, err
1425 }
1426
1427
1428 var commandTag CommandTag
1429 for {
1430 msg, err := pgConn.receiveMessage()
1431 if err != nil {
1432 pgConn.asyncClose()
1433 return nil, preferContextOverNetTimeoutError(ctx, err)
1434 }
1435
1436 switch msg := msg.(type) {
1437 case *pgproto3.ReadyForQuery:
1438 return commandTag, pgErr
1439 case *pgproto3.CommandComplete:
1440 commandTag = CommandTag(msg.CommandTag)
1441 case *pgproto3.ErrorResponse:
1442 pgErr = ErrorResponseToPgError(msg)
1443 }
1444 }
1445 }
1446
1447
// MultiResultReader reads a sequence of results from a connection, one result at a time.
type MultiResultReader struct {
	// pgConn is the connection messages are received from.
	pgConn *PgConn

	// ctx is handed to each ResultReader and used when translating receive errors.
	ctx context.Context

	// rr is the reader for the current result, as returned by ResultReader.
	rr *ResultReader

	// closed is set once ReadyForQuery has been received or receiving a message has failed.
	closed bool

	// err holds the receive error or the most recent server ErrorResponse, if any.
	err error
}
1457
1458
1459 func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
1460 var results []*Result
1461
1462 for mrr.NextResult() {
1463 results = append(results, mrr.ResultReader().Read())
1464 }
1465 err := mrr.Close()
1466
1467 return results, err
1468 }
1469
1470 func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
1471 msg, err := mrr.pgConn.receiveMessage()
1472
1473 if err != nil {
1474 mrr.pgConn.contextWatcher.Unwatch()
1475 mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
1476 mrr.closed = true
1477 mrr.pgConn.asyncClose()
1478 return nil, mrr.err
1479 }
1480
1481 switch msg := msg.(type) {
1482 case *pgproto3.ReadyForQuery:
1483 mrr.pgConn.contextWatcher.Unwatch()
1484 mrr.closed = true
1485 mrr.pgConn.unlock()
1486 case *pgproto3.ErrorResponse:
1487 mrr.err = ErrorResponseToPgError(msg)
1488 }
1489
1490 return msg, nil
1491 }
1492
1493
1494 func (mrr *MultiResultReader) NextResult() bool {
1495 for !mrr.closed && mrr.err == nil {
1496 msg, err := mrr.receiveMessage()
1497 if err != nil {
1498 return false
1499 }
1500
1501 switch msg := msg.(type) {
1502 case *pgproto3.RowDescription:
1503 mrr.pgConn.resultReader = ResultReader{
1504 pgConn: mrr.pgConn,
1505 multiResultReader: mrr,
1506 ctx: mrr.ctx,
1507 fieldDescriptions: msg.Fields,
1508 }
1509 mrr.rr = &mrr.pgConn.resultReader
1510 return true
1511 case *pgproto3.CommandComplete:
1512 mrr.pgConn.resultReader = ResultReader{
1513 commandTag: CommandTag(msg.CommandTag),
1514 commandConcluded: true,
1515 closed: true,
1516 }
1517 mrr.rr = &mrr.pgConn.resultReader
1518 return true
1519 case *pgproto3.EmptyQueryResponse:
1520 return false
1521 }
1522 }
1523
1524 return false
1525 }
1526
1527
1528 func (mrr *MultiResultReader) ResultReader() *ResultReader {
1529 return mrr.rr
1530 }
1531
1532
1533 func (mrr *MultiResultReader) Close() error {
1534 for !mrr.closed {
1535 _, err := mrr.receiveMessage()
1536 if err != nil {
1537 return mrr.err
1538 }
1539 }
1540
1541 return mrr.err
1542 }
1543
1544
// ResultReader reads the rows and outcome of a single command.
type ResultReader struct {
	// pgConn is the connection messages are received from when no multiResultReader is set.
	pgConn *PgConn

	// multiResultReader, when non-nil, is the parent reader that messages are received through.
	multiResultReader *MultiResultReader

	// ctx is used when translating receive errors.
	ctx context.Context

	// fieldDescriptions holds the fields from the most recent RowDescription.
	fieldDescriptions []pgproto3.FieldDescription

	// rowValues holds the values of the most recent DataRow; it is cleared when the command concludes.
	rowValues [][]byte

	// commandTag is the tag recorded when the command concluded.
	commandTag CommandTag

	// commandConcluded is set once CommandComplete, EmptyQueryResponse, ErrorResponse or a receive error is seen.
	commandConcluded bool

	// closed is set by Close or when receiving a message fails.
	closed bool

	// err is the error recorded for this result, if any.
	err error
}
1557
1558
// Result is the fully buffered outcome of a single command, as produced by ResultReader.Read.
type Result struct {
	// FieldDescriptions is a copy of the reader's field descriptions; Read only fills it in when at least one
	// row was read.
	FieldDescriptions []pgproto3.FieldDescription

	// Rows holds one entry per row read, each a slice of column values.
	Rows [][][]byte

	// CommandTag is the command tag returned by closing the reader.
	CommandTag CommandTag

	// Err is the error returned by closing the reader.
	Err error
}
1565
1566
1567 func (rr *ResultReader) Read() *Result {
1568 br := &Result{}
1569
1570 for rr.NextRow() {
1571 if br.FieldDescriptions == nil {
1572 br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions()))
1573 copy(br.FieldDescriptions, rr.FieldDescriptions())
1574 }
1575
1576 row := make([][]byte, len(rr.Values()))
1577 copy(row, rr.Values())
1578 br.Rows = append(br.Rows, row)
1579 }
1580
1581 br.CommandTag, br.Err = rr.Close()
1582
1583 return br
1584 }
1585
1586
1587 func (rr *ResultReader) NextRow() bool {
1588 for !rr.commandConcluded {
1589 msg, err := rr.receiveMessage()
1590 if err != nil {
1591 return false
1592 }
1593
1594 switch msg := msg.(type) {
1595 case *pgproto3.DataRow:
1596 rr.rowValues = msg.Values
1597 return true
1598 }
1599 }
1600
1601 return false
1602 }
1603
1604
1605
1606 func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription {
1607 return rr.fieldDescriptions
1608 }
1609
1610
1611
1612
1613 func (rr *ResultReader) Values() [][]byte {
1614 return rr.rowValues
1615 }
1616
1617
1618
1619 func (rr *ResultReader) Close() (CommandTag, error) {
1620 if rr.closed {
1621 return rr.commandTag, rr.err
1622 }
1623 rr.closed = true
1624
1625 for !rr.commandConcluded {
1626 _, err := rr.receiveMessage()
1627 if err != nil {
1628 return nil, rr.err
1629 }
1630 }
1631
1632 if rr.multiResultReader == nil {
1633 for {
1634 msg, err := rr.receiveMessage()
1635 if err != nil {
1636 return nil, rr.err
1637 }
1638
1639 switch msg := msg.(type) {
1640
1641 case *pgproto3.ErrorResponse:
1642 rr.err = ErrorResponseToPgError(msg)
1643 case *pgproto3.ReadyForQuery:
1644 rr.pgConn.contextWatcher.Unwatch()
1645 rr.pgConn.unlock()
1646 return rr.commandTag, rr.err
1647 }
1648 }
1649 }
1650
1651 return rr.commandTag, rr.err
1652 }
1653
1654
1655
1656 func (rr *ResultReader) readUntilRowDescription() {
1657 for !rr.commandConcluded {
1658
1659
1660
1661 msg, _ := rr.pgConn.peekMessage()
1662 if _, ok := msg.(*pgproto3.DataRow); ok {
1663 return
1664 }
1665
1666
1667 msg, _ = rr.receiveMessage()
1668 if _, ok := msg.(*pgproto3.RowDescription); ok {
1669 return
1670 }
1671 }
1672 }
1673
1674 func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
1675 if rr.multiResultReader == nil {
1676 msg, err = rr.pgConn.receiveMessage()
1677 } else {
1678 msg, err = rr.multiResultReader.receiveMessage()
1679 }
1680
1681 if err != nil {
1682 err = preferContextOverNetTimeoutError(rr.ctx, err)
1683 rr.concludeCommand(nil, err)
1684 rr.pgConn.contextWatcher.Unwatch()
1685 rr.closed = true
1686 if rr.multiResultReader == nil {
1687 rr.pgConn.asyncClose()
1688 }
1689
1690 return nil, rr.err
1691 }
1692
1693 switch msg := msg.(type) {
1694 case *pgproto3.RowDescription:
1695 rr.fieldDescriptions = msg.Fields
1696 case *pgproto3.CommandComplete:
1697 rr.concludeCommand(CommandTag(msg.CommandTag), nil)
1698 case *pgproto3.EmptyQueryResponse:
1699 rr.concludeCommand(nil, nil)
1700 case *pgproto3.ErrorResponse:
1701 rr.concludeCommand(nil, ErrorResponseToPgError(msg))
1702 }
1703
1704 return msg, nil
1705 }
1706
1707 func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
1708
1709
1710 if err != nil && rr.err == nil {
1711 rr.err = err
1712 }
1713
1714 if rr.commandConcluded {
1715 return
1716 }
1717
1718 rr.commandTag = commandTag
1719 rr.rowValues = nil
1720 rr.commandConcluded = true
1721 }
1722
1723
// Batch accumulates encoded protocol messages to be sent together by ExecBatch.
type Batch struct {
	// buf holds the encoded messages queued so far.
	buf []byte

	// err is the first encoding error; once it is set, further queueing calls do nothing.
	err error
}
1728
1729
1730 func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
1731 if batch.err != nil {
1732 return
1733 }
1734
1735 batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
1736 if batch.err != nil {
1737 return
1738 }
1739 batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
1740 }
1741
1742
1743 func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
1744 if batch.err != nil {
1745 return
1746 }
1747
1748 batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
1749 if batch.err != nil {
1750 return
1751 }
1752
1753 batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
1754 if batch.err != nil {
1755 return
1756 }
1757
1758 batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
1759 if batch.err != nil {
1760 return
1761 }
1762 }
1763
1764
1765
1766 func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
1767 if batch.err != nil {
1768 return &MultiResultReader{
1769 closed: true,
1770 err: batch.err,
1771 }
1772 }
1773
1774 if err := pgConn.lock(); err != nil {
1775 return &MultiResultReader{
1776 closed: true,
1777 err: err,
1778 }
1779 }
1780
1781 pgConn.multiResultReader = MultiResultReader{
1782 pgConn: pgConn,
1783 ctx: ctx,
1784 }
1785 multiResult := &pgConn.multiResultReader
1786
1787 if ctx != context.Background() {
1788 select {
1789 case <-ctx.Done():
1790 multiResult.closed = true
1791 multiResult.err = newContextAlreadyDoneError(ctx)
1792 pgConn.unlock()
1793 return multiResult
1794 default:
1795 }
1796 pgConn.contextWatcher.Watch(ctx)
1797 }
1798
1799 batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
1800 if batch.err != nil {
1801 multiResult.closed = true
1802 multiResult.err = batch.err
1803 pgConn.unlock()
1804 return multiResult
1805 }
1806
1807
1808
1809
1810
1811
1812
1813 go func() {
1814 _, err := pgConn.conn.Write(batch.buf)
1815 if err != nil {
1816 pgConn.conn.Close()
1817 }
1818 }()
1819
1820 return multiResult
1821 }
1822
1823
1824
1825
1826
1827
1828 func (pgConn *PgConn) EscapeString(s string) (string, error) {
1829 if pgConn.ParameterStatus("standard_conforming_strings") != "on" {
1830 return "", errors.New("EscapeString must be run with standard_conforming_strings=on")
1831 }
1832
1833 if pgConn.ParameterStatus("client_encoding") != "UTF8" {
1834 return "", errors.New("EscapeString must be run with client_encoding=UTF8")
1835 }
1836
1837 return strings.Replace(s, "'", "''", -1), nil
1838 }
1839
1840
1841
1842
1843
// HijackedConn carries the state of a PgConn that has been taken over with Hijack. Construct builds a new
// PgConn from it.
type HijackedConn struct {
	// Conn is the underlying network connection.
	Conn net.Conn

	// PID is the pid value copied from the PgConn.
	PID uint32

	// SecretKey is the secretKey value copied from the PgConn.
	SecretKey uint32

	// ParameterStatuses is the PgConn's parameter status map (the same map, not a copy).
	ParameterStatuses map[string]string

	// TxStatus is the txStatus value copied from the PgConn.
	TxStatus byte

	// Frontend is the PgConn's message frontend.
	Frontend Frontend

	// Config is the PgConn's configuration.
	Config *Config
}
1853
1854
1855
1856
1857
1858
1859
1860 func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
1861 if err := pgConn.lock(); err != nil {
1862 return nil, err
1863 }
1864 pgConn.status = connStatusClosed
1865
1866 return &HijackedConn{
1867 Conn: pgConn.conn,
1868 PID: pgConn.pid,
1869 SecretKey: pgConn.secretKey,
1870 ParameterStatuses: pgConn.parameterStatuses,
1871 TxStatus: pgConn.txStatus,
1872 Frontend: pgConn.frontend,
1873 Config: pgConn.config,
1874 }, nil
1875 }
1876
1877
1878
1879
1880
1881
1882 func Construct(hc *HijackedConn) (*PgConn, error) {
1883 pgConn := &PgConn{
1884 conn: hc.Conn,
1885 pid: hc.PID,
1886 secretKey: hc.SecretKey,
1887 parameterStatuses: hc.ParameterStatuses,
1888 txStatus: hc.TxStatus,
1889 frontend: hc.Frontend,
1890 config: hc.Config,
1891
1892 status: connStatusIdle,
1893
1894 wbuf: make([]byte, 0, wbufLen),
1895 cleanupDone: make(chan struct{}),
1896 }
1897
1898 pgConn.contextWatcher = newContextWatcher(pgConn.conn)
1899
1900 return pgConn, nil
1901 }
1902
View as plain text