1 package pgx
2
3 import (
4 "context"
5 "crypto/sha256"
6 "encoding/hex"
7 "errors"
8 "fmt"
9 "strconv"
10 "strings"
11 "time"
12
13 "github.com/jackc/pgx/v5/internal/anynil"
14 "github.com/jackc/pgx/v5/internal/sanitize"
15 "github.com/jackc/pgx/v5/internal/stmtcache"
16 "github.com/jackc/pgx/v5/pgconn"
17 "github.com/jackc/pgx/v5/pgtype"
18 )
19
20
21
// ConnConfig holds the settings used to establish a connection. It embeds pgconn.Config and adds the pgx-level
// options below. connect panics unless the value originated from ParseConfig or ParseConfigWithOptions (it may be
// duplicated with Copy afterwards).
type ConnConfig struct {
	pgconn.Config

	// Tracer becomes the connection's query tracer. connect additionally checks whether it implements ConnectTracer,
	// BatchTracer, CopyFromTracer and PrepareTracer and wires up those hooks when it does.
	Tracer QueryTracer

	// connString is the connection string that was handed to ParseConfigWithOptions. It is returned by ConnString.
	connString string

	// StatementCacheCapacity is the capacity passed to the LRU statement cache. The cache is only created when the
	// value is greater than zero. ParseConfigWithOptions defaults it to 512 and overrides it from the
	// "statement_cache_capacity" connection parameter.
	StatementCacheCapacity int

	// DescriptionCacheCapacity is the capacity passed to the LRU description cache. The cache is only created when the
	// value is greater than zero. ParseConfigWithOptions defaults it to 512 and overrides it from the
	// "description_cache_capacity" connection parameter.
	DescriptionCacheCapacity int

	// DefaultQueryExecMode is the execution mode used when a call does not pass an explicit QueryExecMode argument.
	// ParseConfigWithOptions defaults it to QueryExecModeCacheStatement and overrides it from the
	// "default_query_exec_mode" connection parameter.
	DefaultQueryExecMode QueryExecMode

	// createdByParseConfig is set by ParseConfigWithOptions; connect panics when it is false.
	createdByParseConfig bool
}
46
47
// ParseConfigOptions carries the options accepted by ParseConfigWithOptions. It currently only wraps
// pgconn.ParseConfigOptions, which is forwarded unchanged to pgconn.ParseConfigWithOptions.
type ParseConfigOptions struct {
	pgconn.ParseConfigOptions
}
51
52
53
54
55 func (cc *ConnConfig) Copy() *ConnConfig {
56 newConfig := new(ConnConfig)
57 *newConfig = *cc
58 newConfig.Config = *newConfig.Config.Copy()
59 return newConfig
60 }
61
62
63 func (cc *ConnConfig) ConnString() string { return cc.connString }
64
65
66
// Conn is a single PostgreSQL connection. It wraps a *pgconn.PgConn and layers prepared statement bookkeeping,
// statement/description caching, tracing hooks and notification buffering on top of it.
type Conn struct {
	pgConn *pgconn.PgConn // underlying connection established by pgconn.ConnectConfig
	config *ConnConfig    // configuration the connection was created with

	// preparedStatements tracks statements created through Prepare. The key is the caller-supplied name, or the SQL
	// text itself when Prepare was called with name == sql.
	preparedStatements map[string]*pgconn.StatementDescription
	statementCache     stmtcache.Cache // nil when StatementCacheCapacity <= 0
	descriptionCache   stmtcache.Cache // nil when DescriptionCacheCapacity <= 0

	// queryTracer is config.Tracer. The remaining tracers are that same value when it implements the respective
	// interface, and nil otherwise.
	queryTracer    QueryTracer
	batchTracer    BatchTracer
	copyFromTracer CopyFromTracer
	prepareTracer  PrepareTracer

	// notifications buffers notifications delivered through bufferNotifications until WaitForNotification returns them.
	notifications []*pgconn.Notification

	// doneChan and closedChan are allocated by connect.
	// NOTE(review): no reader or writer is visible in this part of the file — confirm whether they are still used.
	doneChan   chan struct{}
	closedChan chan error

	typeMap *pgtype.Map // type registry used to encode arguments and decode results

	wbuf []byte               // scratch buffer allocated with capacity 1024 by connect
	eqb  ExtendedQueryBuilder // reusable builder for extended protocol parameter values and formats
}
89
90
91
// Identifier is a PostgreSQL identifier or name, split into its parts. A table, for example, may be given as
// Identifier{"schema", "table"}.
type Identifier []string

// Sanitize returns the identifier in a form that is safe to splice into SQL text. NUL bytes are dropped from every
// part, embedded double quotes are doubled, each part is wrapped in double quotes, and the parts are joined with ".".
func (ident Identifier) Sanitize() string {
	quoted := make([]string, 0, len(ident))
	for _, part := range ident {
		part = strings.ReplaceAll(part, "\x00", "")
		part = strings.ReplaceAll(part, `"`, `""`)
		quoted = append(quoted, `"`+part+`"`)
	}
	return strings.Join(quoted, ".")
}
103
var (
	// ErrNoRows reports that a result set contained no rows.
	// NOTE(review): the code returning it is not in this part of the file — confirm exact conditions there.
	ErrNoRows = errors.New("no rows in result set")
	// ErrTooManyRows reports that a result set contained more rows than expected.
	// NOTE(review): the code returning it is not in this part of the file — confirm exact conditions there.
	ErrTooManyRows = errors.New("too many rows in result set")
)
110
var (
	// errDisabledStatementCache is returned when QueryExecModeCacheStatement is requested on a connection whose
	// statement cache was not created (StatementCacheCapacity <= 0).
	errDisabledStatementCache = errors.New("cannot use QueryExecModeCacheStatement with disabled statement cache")
	// errDisabledDescriptionCache is returned when QueryExecModeCacheDescribe is requested on a connection whose
	// description cache was not created (DescriptionCacheCapacity <= 0).
	errDisabledDescriptionCache = errors.New("cannot use QueryExecModeCacheDescribe with disabled description cache")
)
113
114
115
116 func Connect(ctx context.Context, connString string) (*Conn, error) {
117 connConfig, err := ParseConfig(connString)
118 if err != nil {
119 return nil, err
120 }
121 return connect(ctx, connConfig)
122 }
123
124
125
126 func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) {
127 connConfig, err := ParseConfigWithOptions(connString, options)
128 if err != nil {
129 return nil, err
130 }
131 return connect(ctx, connConfig)
132 }
133
134
135
136 func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
137
138
139 connConfig = connConfig.Copy()
140
141 return connect(ctx, connConfig)
142 }
143
144
145
146 func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) {
147 config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions)
148 if err != nil {
149 return nil, err
150 }
151
152 statementCacheCapacity := 512
153 if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok {
154 delete(config.RuntimeParams, "statement_cache_capacity")
155 n, err := strconv.ParseInt(s, 10, 32)
156 if err != nil {
157 return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
158 }
159 statementCacheCapacity = int(n)
160 }
161
162 descriptionCacheCapacity := 512
163 if s, ok := config.RuntimeParams["description_cache_capacity"]; ok {
164 delete(config.RuntimeParams, "description_cache_capacity")
165 n, err := strconv.ParseInt(s, 10, 32)
166 if err != nil {
167 return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err)
168 }
169 descriptionCacheCapacity = int(n)
170 }
171
172 defaultQueryExecMode := QueryExecModeCacheStatement
173 if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok {
174 delete(config.RuntimeParams, "default_query_exec_mode")
175 switch s {
176 case "cache_statement":
177 defaultQueryExecMode = QueryExecModeCacheStatement
178 case "cache_describe":
179 defaultQueryExecMode = QueryExecModeCacheDescribe
180 case "describe_exec":
181 defaultQueryExecMode = QueryExecModeDescribeExec
182 case "exec":
183 defaultQueryExecMode = QueryExecModeExec
184 case "simple_protocol":
185 defaultQueryExecMode = QueryExecModeSimpleProtocol
186 default:
187 return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s)
188 }
189 }
190
191 connConfig := &ConnConfig{
192 Config: *config,
193 createdByParseConfig: true,
194 StatementCacheCapacity: statementCacheCapacity,
195 DescriptionCacheCapacity: descriptionCacheCapacity,
196 DefaultQueryExecMode: defaultQueryExecMode,
197 connString: connString,
198 }
199
200 return connConfig, nil
201 }
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217 func ParseConfig(connString string) (*ConnConfig, error) {
218 return ParseConfigWithOptions(connString, ParseConfigOptions{})
219 }
220
221
// connect establishes a connection using config and returns the initialised Conn. The config is stored on the Conn
// and may be modified (see the notification handler below), so callers must pass a value they own.
//
// It panics if config was not created by ParseConfig / ParseConfigWithOptions.
func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
	// When the tracer supports connection tracing, report start now and the outcome on return. The deferred closure
	// reads the named results, so on failure it observes Conn == nil together with the error.
	if connectTracer, ok := config.Tracer.(ConnectTracer); ok {
		ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config})
		defer func() {
			connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err})
		}()
	}

	// Defaults are filled in by ParseConfigWithOptions; refuse hand-built configs instead of guessing defaults from
	// zero values.
	if !config.createdByParseConfig {
		panic("config must be created by ParseConfig")
	}

	c = &Conn{
		config:      config,
		typeMap:     pgtype.NewMap(),
		queryTracer: config.Tracer,
	}

	// Enable the optional tracing hooks that the configured tracer implements.
	if t, ok := c.queryTracer.(BatchTracer); ok {
		c.batchTracer = t
	}
	if t, ok := c.queryTracer.(CopyFromTracer); ok {
		c.copyFromTracer = t
	}
	if t, ok := c.queryTracer.(PrepareTracer); ok {
		c.prepareTracer = t
	}

	// Buffer notifications on the Conn only when the caller has not installed a handler of their own. This must be
	// set before connecting so the handler is part of the config handed to pgconn.
	if config.Config.OnNotification == nil {
		config.Config.OnNotification = c.bufferNotifications
	}

	c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
	if err != nil {
		return nil, err
	}

	c.preparedStatements = make(map[string]*pgconn.StatementDescription)
	c.doneChan = make(chan struct{})
	c.closedChan = make(chan error)
	c.wbuf = make([]byte, 0, 1024)

	// A non-positive capacity leaves the respective cache nil, which disables the matching QueryExecMode.
	if c.config.StatementCacheCapacity > 0 {
		c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
	}

	if c.config.DescriptionCacheCapacity > 0 {
		c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
	}

	return c, nil
}
277
278
279
280 func (c *Conn) Close(ctx context.Context) error {
281 if c.IsClosed() {
282 return nil
283 }
284
285 err := c.pgConn.Close(ctx)
286 return err
287 }
288
289
290
291
292
293
294
295
296
297
298 func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
299 if c.prepareTracer != nil {
300 ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
301 }
302
303 if name != "" {
304 var ok bool
305 if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql {
306 if c.prepareTracer != nil {
307 c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true})
308 }
309 return sd, nil
310 }
311 }
312
313 if c.prepareTracer != nil {
314 defer func() {
315 c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err})
316 }()
317 }
318
319 var psName, psKey string
320 if name == sql {
321 digest := sha256.Sum256([]byte(sql))
322 psName = "stmt_" + hex.EncodeToString(digest[0:24])
323 psKey = sql
324 } else {
325 psName = name
326 psKey = name
327 }
328
329 sd, err = c.pgConn.Prepare(ctx, psName, sql, nil)
330 if err != nil {
331 return nil, err
332 }
333
334 if psKey != "" {
335 c.preparedStatements[psKey] = sd
336 }
337
338 return sd, nil
339 }
340
341
342 func (c *Conn) Deallocate(ctx context.Context, name string) error {
343 var psName string
344 sd := c.preparedStatements[name]
345 if sd != nil {
346 psName = sd.Name
347 } else {
348 psName = name
349 }
350
351 err := c.pgConn.Deallocate(ctx, psName)
352 if err != nil {
353 return err
354 }
355
356 if sd != nil {
357 delete(c.preparedStatements, name)
358 }
359
360 return nil
361 }
362
363
364 func (c *Conn) DeallocateAll(ctx context.Context) error {
365 c.preparedStatements = map[string]*pgconn.StatementDescription{}
366 if c.config.StatementCacheCapacity > 0 {
367 c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
368 }
369 if c.config.DescriptionCacheCapacity > 0 {
370 c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
371 }
372 _, err := c.pgConn.Exec(ctx, "deallocate all").ReadAll()
373 return err
374 }
375
376 func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) {
377 c.notifications = append(c.notifications, n)
378 }
379
380
381
382 func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) {
383 var n *pgconn.Notification
384
385
386 if len(c.notifications) > 0 {
387 n = c.notifications[0]
388 c.notifications = c.notifications[1:]
389 return n, nil
390 }
391
392 err := c.pgConn.WaitForNotification(ctx)
393 if len(c.notifications) > 0 {
394 n = c.notifications[0]
395 c.notifications = c.notifications[1:]
396 }
397 return n, err
398 }
399
400
401 func (c *Conn) IsClosed() bool {
402 return c.pgConn.IsClosed()
403 }
404
405 func (c *Conn) die(err error) {
406 if c.IsClosed() {
407 return
408 }
409
410 ctx, cancel := context.WithCancel(context.Background())
411 cancel()
412 c.pgConn.Close(ctx)
413 }
414
// quoteIdentifier returns s as a double-quoted SQL identifier. Embedded double quotes are doubled and NUL bytes are
// removed, matching what Identifier.Sanitize does for each identifier part: a quoted identifier cannot contain the
// zero byte, so leaving it in would produce text the server cannot accept as written.
func quoteIdentifier(s string) string {
	s = strings.ReplaceAll(s, "\x00", "")
	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
418
419
420 func (c *Conn) Ping(ctx context.Context) error {
421 return c.pgConn.Ping(ctx)
422 }
423
424
425
426
427
428
429 func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn }
430
431
432 func (c *Conn) TypeMap() *pgtype.Map { return c.typeMap }
433
434
435 func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
436
437
438
439 func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
440 if c.queryTracer != nil {
441 ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments})
442 }
443
444 if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
445 return pgconn.CommandTag{}, err
446 }
447
448 commandTag, err := c.exec(ctx, sql, arguments...)
449
450 if c.queryTracer != nil {
451 c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err})
452 }
453
454 return commandTag, err
455 }
456
457 func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
458 mode := c.config.DefaultQueryExecMode
459 var queryRewriter QueryRewriter
460
461 optionLoop:
462 for len(arguments) > 0 {
463 switch arg := arguments[0].(type) {
464 case QueryExecMode:
465 mode = arg
466 arguments = arguments[1:]
467 case QueryRewriter:
468 queryRewriter = arg
469 arguments = arguments[1:]
470 default:
471 break optionLoop
472 }
473 }
474
475 if queryRewriter != nil {
476 sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
477 if err != nil {
478 return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", err)
479 }
480 }
481
482
483 if len(arguments) == 0 {
484 mode = QueryExecModeSimpleProtocol
485 }
486
487 if sd, ok := c.preparedStatements[sql]; ok {
488 return c.execPrepared(ctx, sd, arguments)
489 }
490
491 switch mode {
492 case QueryExecModeCacheStatement:
493 if c.statementCache == nil {
494 return pgconn.CommandTag{}, errDisabledStatementCache
495 }
496 sd := c.statementCache.Get(sql)
497 if sd == nil {
498 sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
499 if err != nil {
500 return pgconn.CommandTag{}, err
501 }
502 c.statementCache.Put(sd)
503 }
504
505 return c.execPrepared(ctx, sd, arguments)
506 case QueryExecModeCacheDescribe:
507 if c.descriptionCache == nil {
508 return pgconn.CommandTag{}, errDisabledDescriptionCache
509 }
510 sd := c.descriptionCache.Get(sql)
511 if sd == nil {
512 sd, err = c.Prepare(ctx, "", sql)
513 if err != nil {
514 return pgconn.CommandTag{}, err
515 }
516 c.descriptionCache.Put(sd)
517 }
518
519 return c.execParams(ctx, sd, arguments)
520 case QueryExecModeDescribeExec:
521 sd, err := c.Prepare(ctx, "", sql)
522 if err != nil {
523 return pgconn.CommandTag{}, err
524 }
525 return c.execPrepared(ctx, sd, arguments)
526 case QueryExecModeExec:
527 return c.execSQLParams(ctx, sql, arguments)
528 case QueryExecModeSimpleProtocol:
529 return c.execSimpleProtocol(ctx, sql, arguments)
530 default:
531 return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode)
532 }
533 }
534
535 func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []any) (commandTag pgconn.CommandTag, err error) {
536 if len(arguments) > 0 {
537 sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
538 if err != nil {
539 return pgconn.CommandTag{}, err
540 }
541 }
542
543 mrr := c.pgConn.Exec(ctx, sql)
544 for mrr.NextResult() {
545 commandTag, _ = mrr.ResultReader().Close()
546 }
547 err = mrr.Close()
548 return commandTag, err
549 }
550
551 func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) {
552 err := c.eqb.Build(c.typeMap, sd, arguments)
553 if err != nil {
554 return pgconn.CommandTag{}, err
555 }
556
557 result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
558 c.eqb.reset()
559 return result.CommandTag, result.Err
560 }
561
562 func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) {
563 err := c.eqb.Build(c.typeMap, sd, arguments)
564 if err != nil {
565 return pgconn.CommandTag{}, err
566 }
567
568 result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
569 c.eqb.reset()
570 return result.CommandTag, result.Err
571 }
572
// unknownArgumentTypeQueryExecModeExecError describes a query argument whose Go type cannot be used with
// QueryExecModeExec. arg is the offending argument value.
type unknownArgumentTypeQueryExecModeExecError struct {
	arg any
}

// Error implements the error interface. The message names the dynamic type of the offending argument.
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
	typeName := fmt.Sprintf("%T", e.arg)
	return "cannot use unregistered type " + typeName + " as query argument in QueryExecModeExec"
}
580
581 func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
582 err := c.eqb.Build(c.typeMap, nil, args)
583 if err != nil {
584 return pgconn.CommandTag{}, err
585 }
586
587 result := c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
588 c.eqb.reset()
589 return result.CommandTag, result.Err
590 }
591
592 func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows {
593 r := &baseRows{}
594
595 r.ctx = ctx
596 r.queryTracer = c.queryTracer
597 r.typeMap = c.typeMap
598 r.startTime = time.Now()
599 r.sql = sql
600 r.args = args
601 r.conn = c
602
603 return r
604 }
605
// QueryExecMode selects how a query is described and sent to the server. It can be set connection-wide through
// ConnConfig.DefaultQueryExecMode or passed as a leading argument to Exec and Query. The zero value is not a valid
// mode.
type QueryExecMode int32

const (
	// QueryExecModeCacheStatement prepares the statement on the server under a name derived from the SQL text and
	// keeps its description in the connection's statement cache, so later executions reuse the prepared statement.
	// It requires the statement cache to be enabled. This is the default chosen by ParseConfigWithOptions.
	QueryExecModeCacheStatement QueryExecMode = iota + 1

	// QueryExecModeCacheDescribe describes the statement through an unnamed prepare and keeps the description in
	// the connection's description cache. Executions send the SQL text together with the cached parameter OIDs
	// instead of referring to a named prepared statement. It requires the description cache to be enabled.
	QueryExecModeCacheDescribe

	// QueryExecModeDescribeExec describes the statement through an unnamed prepare on every call and then executes
	// it. Nothing is cached.
	QueryExecModeDescribeExec

	// QueryExecModeExec sends the SQL text and encoded parameters in one parameterised query without obtaining a
	// statement description first; no parameter OIDs are sent.
	QueryExecModeExec

	// QueryExecModeSimpleProtocol interpolates the arguments into the SQL text on the client and sends the result
	// with the simple query protocol. It is only permitted when the connection reports
	// standard_conforming_strings=on and client_encoding=UTF8. Exec always uses this mode for queries without
	// arguments.
	QueryExecModeSimpleProtocol
)

// String returns a human-readable name for the mode, or "invalid" for values outside the defined set.
func (m QueryExecMode) String() string {
	names := [...]string{
		QueryExecModeCacheStatement: "cache statement",
		QueryExecModeCacheDescribe:  "cache describe",
		QueryExecModeDescribeExec:   "describe exec",
		QueryExecModeExec:           "exec",
		QueryExecModeSimpleProtocol: "simple protocol",
	}

	if m >= QueryExecModeCacheStatement && int(m) < len(names) {
		return names[m]
	}
	return "invalid"
}
669
670
// QueryResultFormats may be passed as a leading argument to Query to set the result format codes that are sent to
// the server for statements executed with a statement description.
type QueryResultFormats []int16
672
673
// QueryResultFormatsByOID may be passed as a leading argument to Query to choose the result format code of each
// column by its data type OID. Columns whose OID is missing from the map get format code 0.
type QueryResultFormatsByOID map[uint32]int16
675
676
// QueryRewriter may be passed as a leading argument to Exec and Query, and as the first argument of a queued batch
// query. Its RewriteQuery method receives the SQL text and the remaining arguments and returns the SQL and
// arguments that are actually executed. A non-nil error aborts the call and is reported wrapped in a
// "rewrite query failed" error.
type QueryRewriter interface {
	RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
}
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
// Query sends sql to the server and returns a Rows for reading the results. The returned Rows is always non-nil,
// even when an error is returned; in that case the Rows carries the same error.
//
// args may start with any number of option values, which are consumed before the query parameters:
// QueryResultFormats, QueryResultFormatsByOID, QueryExecMode and QueryRewriter. The first argument that is none of
// these ends option processing.
//
// If sql matches a statement recorded by Prepare, that prepared statement is executed regardless of the mode.
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
	if c.queryTracer != nil {
		ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args})
	}

	if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
		// The Rows returned here is a bare value without a tracer, so the end event is reported explicitly.
		if c.queryTracer != nil {
			c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err})
		}
		return &baseRows{err: err, closed: true}, err
	}

	var resultFormats QueryResultFormats
	var resultFormatsByOID QueryResultFormatsByOID
	mode := c.config.DefaultQueryExecMode
	var queryRewriter QueryRewriter

	// Consume leading option arguments; stop at the first value that is a real query parameter.
optionLoop:
	for len(args) > 0 {
		switch arg := args[0].(type) {
		case QueryResultFormats:
			resultFormats = arg
			args = args[1:]
		case QueryResultFormatsByOID:
			resultFormatsByOID = arg
			args = args[1:]
		case QueryExecMode:
			mode = arg
			args = args[1:]
		case QueryRewriter:
			queryRewriter = arg
			args = args[1:]
		default:
			break optionLoop
		}
	}

	if queryRewriter != nil {
		var err error
		// Keep the pre-rewrite values so a failed rewrite is reported against what the caller passed in.
		originalSQL := sql
		originalArgs := args
		sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
		if err != nil {
			rows := c.getRows(ctx, originalSQL, originalArgs)
			err = fmt.Errorf("rewrite query failed: %w", err)
			rows.fatal(err)
			return rows, err
		}
	}

	// An empty query is sent with the simple protocol, bypassing statement preparation and caching.
	if sql == "" {
		mode = QueryExecModeSimpleProtocol
	}

	c.eqb.reset()
	anynil.NormalizeSlice(args)
	rows := c.getRows(ctx, sql, args)

	var err error
	sd, explicitPreparedStatement := c.preparedStatements[sql]
	if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
		// Modes that work from a statement description. An explicitly prepared statement takes this path too.
		if sd == nil {
			sd, err = c.getStatementDescription(ctx, mode, sql)
			if err != nil {
				rows.fatal(err)
				return rows, err
			}
		}

		if len(sd.ParamOIDs) != len(args) {
			rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
			return rows, rows.err
		}

		// For an explicitly prepared statement sql was its name; record the actual SQL text on the rows.
		rows.sql = sd.SQL

		err = c.eqb.Build(c.typeMap, sd, args)
		if err != nil {
			rows.fatal(err)
			return rows, rows.err
		}

		// Formats chosen by OID take precedence over QueryResultFormats; missing OIDs map to format 0.
		if resultFormatsByOID != nil {
			resultFormats = make([]int16, len(sd.Fields))
			for i := range resultFormats {
				resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
			}
		}

		if resultFormats == nil {
			resultFormats = c.eqb.ResultFormats
		}

		// Cached descriptions come from an unnamed prepare, so there is no named server-side statement to
		// execute: send the SQL text with the described parameter OIDs instead.
		if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe {
			rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats)
		} else {
			rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats)
		}
	} else if mode == QueryExecModeExec {
		// No statement description: encode the arguments without one and send no parameter OIDs.
		err := c.eqb.Build(c.typeMap, nil, args)
		if err != nil {
			rows.fatal(err)
			return rows, rows.err
		}

		rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
	} else if mode == QueryExecModeSimpleProtocol {
		sql, err = c.sanitizeForSimpleQuery(sql, args...)
		if err != nil {
			rows.fatal(err)
			return rows, err
		}

		// Only the first result is exposed through rows. The multi-result reader is kept on rows as well.
		mrr := c.pgConn.Exec(ctx, sql)
		if mrr.NextResult() {
			rows.resultReader = mrr.ResultReader()
			rows.multiResultReader = mrr
		} else {
			err = mrr.Close()
			rows.fatal(err)
			return rows, err
		}

		return rows, nil
	} else {
		err = fmt.Errorf("unknown QueryExecMode: %v", mode)
		rows.fatal(err)
		return rows, rows.err
	}

	c.eqb.reset()

	return rows, rows.err
}
837
838
839
840
841
842
843 func (c *Conn) getStatementDescription(
844 ctx context.Context,
845 mode QueryExecMode,
846 sql string,
847 ) (sd *pgconn.StatementDescription, err error) {
848
849 switch mode {
850 case QueryExecModeCacheStatement:
851 if c.statementCache == nil {
852 return nil, errDisabledStatementCache
853 }
854 sd = c.statementCache.Get(sql)
855 if sd == nil {
856 sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
857 if err != nil {
858 return nil, err
859 }
860 c.statementCache.Put(sd)
861 }
862 case QueryExecModeCacheDescribe:
863 if c.descriptionCache == nil {
864 return nil, errDisabledDescriptionCache
865 }
866 sd = c.descriptionCache.Get(sql)
867 if sd == nil {
868 sd, err = c.Prepare(ctx, "", sql)
869 if err != nil {
870 return nil, err
871 }
872 c.descriptionCache.Put(sd)
873 }
874 case QueryExecModeDescribeExec:
875 return c.Prepare(ctx, "", sql)
876 }
877 return sd, err
878 }
879
880
881
882
883 func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
884 rows, _ := c.Query(ctx, sql, args...)
885 return (*connRow)(rows.(*baseRows))
886 }
887
888
889
890
891 func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
892 if c.batchTracer != nil {
893 ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b})
894 defer func() {
895 err := br.(interface{ earlyError() error }).earlyError()
896 if err != nil {
897 c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err})
898 }
899 }()
900 }
901
902 if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
903 return &batchResults{ctx: ctx, conn: c, err: err}
904 }
905
906 for _, bi := range b.QueuedQueries {
907 var queryRewriter QueryRewriter
908 sql := bi.SQL
909 arguments := bi.Arguments
910
911 optionLoop:
912 for len(arguments) > 0 {
913
914 switch arg := arguments[0].(type) {
915 case QueryRewriter:
916 queryRewriter = arg
917 arguments = arguments[1:]
918 default:
919 break optionLoop
920 }
921 }
922
923 if queryRewriter != nil {
924 var err error
925 sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
926 if err != nil {
927 return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %w", err)}
928 }
929 }
930
931 bi.SQL = sql
932 bi.Arguments = arguments
933 }
934
935
936 mode := c.config.DefaultQueryExecMode
937 if mode == QueryExecModeSimpleProtocol {
938 return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
939 }
940
941
942 for _, bi := range b.QueuedQueries {
943 if sd, ok := c.preparedStatements[bi.SQL]; ok {
944 bi.sd = sd
945 }
946 }
947
948 switch mode {
949 case QueryExecModeExec:
950 return c.sendBatchQueryExecModeExec(ctx, b)
951 case QueryExecModeCacheStatement:
952 return c.sendBatchQueryExecModeCacheStatement(ctx, b)
953 case QueryExecModeCacheDescribe:
954 return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
955 case QueryExecModeDescribeExec:
956 return c.sendBatchQueryExecModeDescribeExec(ctx, b)
957 default:
958 panic("unknown QueryExecMode")
959 }
960 }
961
962 func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
963 var sb strings.Builder
964 for i, bi := range b.QueuedQueries {
965 if i > 0 {
966 sb.WriteByte(';')
967 }
968 sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
969 if err != nil {
970 return &batchResults{ctx: ctx, conn: c, err: err}
971 }
972 sb.WriteString(sql)
973 }
974 mrr := c.pgConn.Exec(ctx, sb.String())
975 return &batchResults{
976 ctx: ctx,
977 conn: c,
978 mrr: mrr,
979 b: b,
980 qqIdx: 0,
981 }
982 }
983
984 func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
985 batch := &pgconn.Batch{}
986
987 for _, bi := range b.QueuedQueries {
988 sd := bi.sd
989 if sd != nil {
990 err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
991 if err != nil {
992 return &batchResults{ctx: ctx, conn: c, err: err}
993 }
994
995 batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
996 } else {
997 err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
998 if err != nil {
999 return &batchResults{ctx: ctx, conn: c, err: err}
1000 }
1001 batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
1002 }
1003 }
1004
1005 c.eqb.reset()
1006
1007 mrr := c.pgConn.ExecBatch(ctx, batch)
1008
1009 return &batchResults{
1010 ctx: ctx,
1011 conn: c,
1012 mrr: mrr,
1013 b: b,
1014 qqIdx: 0,
1015 }
1016 }
1017
1018 func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
1019 if c.statementCache == nil {
1020 return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}
1021 }
1022
1023 distinctNewQueries := []*pgconn.StatementDescription{}
1024 distinctNewQueriesIdxMap := make(map[string]int)
1025
1026 for _, bi := range b.QueuedQueries {
1027 if bi.sd == nil {
1028 sd := c.statementCache.Get(bi.SQL)
1029 if sd != nil {
1030 bi.sd = sd
1031 } else {
1032 if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
1033 bi.sd = distinctNewQueries[idx]
1034 } else {
1035 sd = &pgconn.StatementDescription{
1036 Name: stmtcache.StatementName(bi.SQL),
1037 SQL: bi.SQL,
1038 }
1039 distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
1040 distinctNewQueries = append(distinctNewQueries, sd)
1041 bi.sd = sd
1042 }
1043 }
1044 }
1045 }
1046
1047 return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache)
1048 }
1049
1050 func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
1051 if c.descriptionCache == nil {
1052 return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}
1053 }
1054
1055 distinctNewQueries := []*pgconn.StatementDescription{}
1056 distinctNewQueriesIdxMap := make(map[string]int)
1057
1058 for _, bi := range b.QueuedQueries {
1059 if bi.sd == nil {
1060 sd := c.descriptionCache.Get(bi.SQL)
1061 if sd != nil {
1062 bi.sd = sd
1063 } else {
1064 if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
1065 bi.sd = distinctNewQueries[idx]
1066 } else {
1067 sd = &pgconn.StatementDescription{
1068 SQL: bi.SQL,
1069 }
1070 distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
1071 distinctNewQueries = append(distinctNewQueries, sd)
1072 bi.sd = sd
1073 }
1074 }
1075 }
1076 }
1077
1078 return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache)
1079 }
1080
1081 func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
1082 distinctNewQueries := []*pgconn.StatementDescription{}
1083 distinctNewQueriesIdxMap := make(map[string]int)
1084
1085 for _, bi := range b.QueuedQueries {
1086 if bi.sd == nil {
1087 if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
1088 bi.sd = distinctNewQueries[idx]
1089 } else {
1090 sd := &pgconn.StatementDescription{
1091 SQL: bi.SQL,
1092 }
1093 distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
1094 distinctNewQueries = append(distinctNewQueries, sd)
1095 bi.sd = sd
1096 }
1097 }
1098 }
1099
1100 return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil)
1101 }
1102
// sendBatchExtendedWithDescription sends a batch through a pipeline in two phases. First every statement in
// distinctNewQueries is prepared and its description filled in from the server's response. Then all queued queries
// are encoded against their descriptions and sent. When sdCache is non-nil the newly described statements are
// stored in it.
//
// Every queued query in b must already have its sd field set; entries of distinctNewQueries are expected to be the
// same pointers so that filling them in completes the queued queries' descriptions.
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
	pipeline := c.pgConn.StartPipeline(ctx)
	// On any early failure the results never take ownership of the pipeline, so close it here.
	defer func() {
		if pbr != nil && pbr.err != nil {
			pipeline.Close()
		}
	}()

	// Phase 1: prepare the statements that have no description yet.
	if len(distinctNewQueries) > 0 {
		for _, sd := range distinctNewQueries {
			pipeline.SendPrepare(sd.Name, sd.SQL, nil)
		}

		err := pipeline.Sync()
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
		}

		// Results are read in the order the prepares were sent.
		for _, sd := range distinctNewQueries {
			results, err := pipeline.GetResults()
			if err != nil {
				return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
			}

			resultSD, ok := results.(*pgconn.StatementDescription)
			if !ok {
				return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
			}

			// Complete the pending description in place; the queued queries share this pointer.
			sd.ParamOIDs = resultSD.ParamOIDs
			sd.Fields = resultSD.Fields
		}

		// The sync that ended the prepare phase produces one more result.
		results, err := pipeline.GetResults()
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
		}

		_, ok := results.(*pgconn.PipelineSync)
		if !ok {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
		}
	}

	// Remember the newly described statements for later calls.
	if sdCache != nil {
		for _, sd := range distinctNewQueries {
			sdCache.Put(sd)
		}
	}

	// Phase 2: encode and queue every query of the batch.
	for _, bi := range b.QueuedQueries {
		err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
		if err != nil {
			// Name the offending query so the caller can tell which batch entry failed.
			err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
		}

		// An unnamed description has no server-side statement to refer to, so the SQL text is sent instead.
		if bi.sd.Name == "" {
			pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
		} else {
			pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
		}
	}

	err := pipeline.Sync()
	if err != nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
	}

	return &pipelineBatchResults{
		ctx:      ctx,
		conn:     c,
		pipeline: pipeline,
		b:        b,
	}
}
1184
1185 func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
1186 if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
1187 return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
1188 }
1189
1190 if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
1191 return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
1192 }
1193
1194 var err error
1195 valueArgs := make([]any, len(args))
1196 for i, a := range args {
1197 valueArgs[i], err = convertSimpleArgument(c.typeMap, a)
1198 if err != nil {
1199 return "", err
1200 }
1201 }
1202
1203 return sanitize.SanitizeSQL(sql, valueArgs...)
1204 }
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215 func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
1216 var oid uint32
1217
1218 err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
1219 if err != nil {
1220 return nil, err
1221 }
1222
1223 var typtype string
1224 var typbasetype uint32
1225
1226 err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
1227 if err != nil {
1228 return nil, err
1229 }
1230
1231 switch typtype {
1232 case "b":
1233 elementOID, err := c.getArrayElementOID(ctx, oid)
1234 if err != nil {
1235 return nil, err
1236 }
1237
1238 dt, ok := c.TypeMap().TypeForOID(elementOID)
1239 if !ok {
1240 return nil, errors.New("array element OID not registered")
1241 }
1242
1243 return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}, nil
1244 case "c":
1245 fields, err := c.getCompositeFields(ctx, oid)
1246 if err != nil {
1247 return nil, err
1248 }
1249
1250 return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
1251 case "d":
1252 dt, ok := c.TypeMap().TypeForOID(typbasetype)
1253 if !ok {
1254 return nil, errors.New("domain base type OID not registered")
1255 }
1256
1257 return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
1258 case "e":
1259 return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
1260 case "r":
1261 elementOID, err := c.getRangeElementOID(ctx, oid)
1262 if err != nil {
1263 return nil, err
1264 }
1265
1266 dt, ok := c.TypeMap().TypeForOID(elementOID)
1267 if !ok {
1268 return nil, errors.New("range element OID not registered")
1269 }
1270
1271 return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil
1272 case "m":
1273 elementOID, err := c.getMultiRangeElementOID(ctx, oid)
1274 if err != nil {
1275 return nil, err
1276 }
1277
1278 dt, ok := c.TypeMap().TypeForOID(elementOID)
1279 if !ok {
1280 return nil, errors.New("multirange element OID not registered")
1281 }
1282
1283 return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil
1284 default:
1285 return &pgtype.Type{}, errors.New("unknown typtype")
1286 }
1287 }
1288
1289 func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) {
1290 var typelem uint32
1291
1292 err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem)
1293 if err != nil {
1294 return 0, err
1295 }
1296
1297 return typelem, nil
1298 }
1299
1300 func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
1301 var typelem uint32
1302
1303 err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem)
1304 if err != nil {
1305 return 0, err
1306 }
1307
1308 return typelem, nil
1309 }
1310
1311 func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
1312 var typelem uint32
1313
1314 err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem)
1315 if err != nil {
1316 return 0, err
1317 }
1318
1319 return typelem, nil
1320 }
1321
1322 func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
1323 var typrelid uint32
1324
1325 err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
1326 if err != nil {
1327 return nil, err
1328 }
1329
1330 var fields []pgtype.CompositeCodecField
1331 var fieldName string
1332 var fieldOID uint32
1333 rows, _ := c.Query(ctx, `select attname, atttypid
1334 from pg_attribute
1335 where attrelid=$1
1336 and not attisdropped
1337 and attnum > 0
1338 order by attnum`,
1339 typrelid,
1340 )
1341 _, err = ForEachRow(rows, []any{&fieldName, &fieldOID}, func() error {
1342 dt, ok := c.TypeMap().TypeForOID(fieldOID)
1343 if !ok {
1344 return fmt.Errorf("unknown composite type field OID: %v", fieldOID)
1345 }
1346 fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
1347 return nil
1348 })
1349 if err != nil {
1350 return nil, err
1351 }
1352
1353 return fields, nil
1354 }
1355
// deallocateInvalidatedCachedStatements removes invalidated entries from the description cache and deallocates
// invalidated statement cache entries on the server. It is called by Exec, Query and SendBatch before they do any
// other work. It returns nil without doing anything when the transaction status is neither 'I' nor 'T'.
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
	// NOTE(review): 'I' and 'T' are presumably the PostgreSQL "idle" and "in transaction" status codes, leaving
	// out the failed-transaction state in which commands would be rejected — confirm against pgconn.TxStatus.
	if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
		return nil
	}

	// Description cache entries have no server-side resources, so they are simply dropped.
	if c.descriptionCache != nil {
		c.descriptionCache.RemoveInvalidated()
	}

	var invalidatedStatements []*pgconn.StatementDescription
	if c.statementCache != nil {
		invalidatedStatements = c.statementCache.GetInvalidated()
	}

	if len(invalidatedStatements) == 0 {
		return nil
	}

	pipeline := c.pgConn.StartPipeline(ctx)
	// Ensures the pipeline is closed on the early error return below; the explicit Close further down handles
	// the success path and reports its error.
	defer pipeline.Close()

	for _, sd := range invalidatedStatements {
		pipeline.SendDeallocate(sd.Name)
	}

	err := pipeline.Sync()
	if err != nil {
		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
	}

	err = pipeline.Close()
	if err != nil {
		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
	}

	// Only forget the statements locally once the server-side deallocation has gone through.
	c.statementCache.RemoveInvalidated()
	for _, sd := range invalidatedStatements {
		delete(c.preparedStatements, sd.Name)
	}

	return nil
}
1398
View as plain text