1 package pgxpool
2
3 import (
4 "context"
5 "fmt"
6 "math/rand"
7 "runtime"
8 "strconv"
9 "sync"
10 "sync/atomic"
11 "time"
12
13 "github.com/jackc/pgx/v5"
14 "github.com/jackc/pgx/v5/pgconn"
15 "github.com/jackc/puddle/v2"
16 )
17
// Defaults applied by ParseConfig when the corresponding pool_* setting is
// absent from the connection string.
var (
	defaultMaxConns          = int32(4)
	defaultMinConns          = int32(0)
	defaultMaxConnLifetime   = time.Hour
	defaultMaxConnIdleTime   = 30 * time.Minute
	defaultHealthCheckPeriod = time.Minute
)
23
24 type connResource struct {
25 conn *pgx.Conn
26 conns []Conn
27 poolRows []poolRow
28 poolRowss []poolRows
29 maxAgeTime time.Time
30 }
31
32 func (cr *connResource) getConn(p *Pool, res *puddle.Resource[*connResource]) *Conn {
33 if len(cr.conns) == 0 {
34 cr.conns = make([]Conn, 128)
35 }
36
37 c := &cr.conns[len(cr.conns)-1]
38 cr.conns = cr.conns[0 : len(cr.conns)-1]
39
40 c.res = res
41 c.p = p
42
43 return c
44 }
45
46 func (cr *connResource) getPoolRow(c *Conn, r pgx.Row) *poolRow {
47 if len(cr.poolRows) == 0 {
48 cr.poolRows = make([]poolRow, 128)
49 }
50
51 pr := &cr.poolRows[len(cr.poolRows)-1]
52 cr.poolRows = cr.poolRows[0 : len(cr.poolRows)-1]
53
54 pr.c = c
55 pr.r = r
56
57 return pr
58 }
59
60 func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows {
61 if len(cr.poolRowss) == 0 {
62 cr.poolRowss = make([]poolRows, 128)
63 }
64
65 pr := &cr.poolRowss[len(cr.poolRowss)-1]
66 cr.poolRowss = cr.poolRowss[0 : len(cr.poolRowss)-1]
67
68 pr.c = c
69 pr.r = r
70
71 return pr
72 }
73
74
// Pool manages a set of *pgx.Conn through a puddle resource pool, with a
// background goroutine that keeps MinConns open and prunes expired or idle
// connections.
type Pool struct {
	// 64-bit fields accessed with sync/atomic must stay at the beginning of
	// the struct. On some 32-bit platforms, 64-bit atomic operations require
	// 8-byte alignment, and only the first word of an allocated struct is
	// guaranteed to be aligned that way. Do not reorder.
	newConnsCount        int64 // incremented each time the puddle constructor opens a connection
	lifetimeDestroyCount int64 // incremented when the health check destroys an expired connection
	idleDestroyCount     int64 // incremented when the health check destroys an over-idle connection

	p      *puddle.Pool[*connResource] // underlying resource pool
	config *Config                     // config passed to NewWithConfig; Config() returns copies

	// Hooks and limits copied from config by NewWithConfig.
	beforeConnect         func(context.Context, *pgx.ConnConfig) error
	afterConnect          func(context.Context, *pgx.Conn) error
	beforeAcquire         func(context.Context, *pgx.Conn) bool
	afterRelease          func(*pgx.Conn) bool
	beforeClose           func(*pgx.Conn)
	minConns              int32
	maxConns              int32
	maxConnLifetime       time.Duration
	maxConnLifetimeJitter time.Duration
	maxConnIdleTime       time.Duration
	healthCheckPeriod     time.Duration

	// healthCheckChan carries requests for an extra health check; sends from
	// triggerHealthCheck are non-blocking.
	healthCheckChan chan struct{}

	closeOnce sync.Once     // makes Close's shutdown run only once
	closeChan chan struct{} // closed by Close to stop the background health check
}
101
102
103
// Config is the configuration for creating a pool. It must be created by
// ParseConfig and may then be modified before being passed to NewWithConfig.
// NewWithConfig panics when given a Config that did not come from ParseConfig.
type Config struct {
	// ConnConfig is the configuration used for each new connection. The pool
	// copies it before every connect, so changes BeforeConnect makes to its
	// argument do not carry over to later connections.
	ConnConfig *pgx.ConnConfig

	// BeforeConnect is called before a new connection is made. It receives a
	// copy of ConnConfig, which it may modify. Returning an error aborts the
	// connection attempt with that error.
	BeforeConnect func(context.Context, *pgx.ConnConfig) error

	// AfterConnect is called after a connection is established, before it is
	// added to the pool. If it returns an error, the connection is closed and
	// the connection attempt fails with that error.
	AfterConnect func(context.Context, *pgx.Conn) error

	// BeforeAcquire is called before a connection is handed out by Acquire or
	// AcquireAllIdle. Returning true allows the acquisition. Returning false
	// destroys the connection, and Acquire then tries another one.
	BeforeAcquire func(context.Context, *pgx.Conn) bool

	// AfterRelease is called when a connection is released back to the pool.
	// NOTE(review): the release path is not in this section of the file; per
	// pgx convention, returning false presumably destroys the connection
	// instead of returning it to the pool. Confirm against Conn.Release.
	AfterRelease func(*pgx.Conn) bool

	// BeforeClose is called right before the pool closes a connection.
	BeforeClose func(*pgx.Conn)

	// MaxConnLifetime is the time since creation after which a connection is
	// considered expired. Expired idle connections are closed by the health
	// check.
	MaxConnLifetime time.Duration

	// MaxConnLifetimeJitter adds a random duration of up to this value to each
	// connection's lifetime, so connections created together do not all
	// expire at once.
	MaxConnLifetimeJitter time.Duration

	// MaxConnIdleTime is the idle duration after which the health check closes
	// a connection, provided the pool stays above MinConns.
	MaxConnIdleTime time.Duration

	// MaxConns is the maximum size of the pool. ParseConfig defaults it to the
	// greater of 4 and runtime.NumCPU().
	MaxConns int32

	// MinConns is the number of connections the pool tries to keep open.
	// NewWithConfig and each health check create connections to reach it.
	// Idle pruning never goes below it, but lifetime expiry can drop below it
	// until the next refill.
	MinConns int32

	// HealthCheckPeriod is the interval between periodic background health
	// checks.
	HealthCheckPeriod time.Duration

	// createdByParseConfig is set by ParseConfig, which is where defaults are
	// populated. NewWithConfig panics on a Config without it.
	createdByParseConfig bool
}
149
150
151
152
153 func (c *Config) Copy() *Config {
154 newConfig := new(Config)
155 *newConfig = *c
156 newConfig.ConnConfig = c.ConnConfig.Copy()
157 return newConfig
158 }
159
160
161 func (c *Config) ConnString() string { return c.ConnConfig.ConnString() }
162
163
164 func New(ctx context.Context, connString string) (*Pool, error) {
165 config, err := ParseConfig(connString)
166 if err != nil {
167 return nil, err
168 }
169
170 return NewWithConfig(ctx, config)
171 }
172
173
174 func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
175
176
177 if !config.createdByParseConfig {
178 panic("config must be created by ParseConfig")
179 }
180
181 p := &Pool{
182 config: config,
183 beforeConnect: config.BeforeConnect,
184 afterConnect: config.AfterConnect,
185 beforeAcquire: config.BeforeAcquire,
186 afterRelease: config.AfterRelease,
187 beforeClose: config.BeforeClose,
188 minConns: config.MinConns,
189 maxConns: config.MaxConns,
190 maxConnLifetime: config.MaxConnLifetime,
191 maxConnLifetimeJitter: config.MaxConnLifetimeJitter,
192 maxConnIdleTime: config.MaxConnIdleTime,
193 healthCheckPeriod: config.HealthCheckPeriod,
194 healthCheckChan: make(chan struct{}, 1),
195 closeChan: make(chan struct{}),
196 }
197
198 var err error
199 p.p, err = puddle.NewPool(
200 &puddle.Config[*connResource]{
201 Constructor: func(ctx context.Context) (*connResource, error) {
202 atomic.AddInt64(&p.newConnsCount, 1)
203 connConfig := p.config.ConnConfig.Copy()
204
205
206 if connConfig.ConnectTimeout <= 0 {
207 connConfig.ConnectTimeout = 2 * time.Minute
208 }
209
210 if p.beforeConnect != nil {
211 if err := p.beforeConnect(ctx, connConfig); err != nil {
212 return nil, err
213 }
214 }
215
216 conn, err := pgx.ConnectConfig(ctx, connConfig)
217 if err != nil {
218 return nil, err
219 }
220
221 if p.afterConnect != nil {
222 err = p.afterConnect(ctx, conn)
223 if err != nil {
224 conn.Close(ctx)
225 return nil, err
226 }
227 }
228
229 jitterSecs := rand.Float64() * config.MaxConnLifetimeJitter.Seconds()
230 maxAgeTime := time.Now().Add(config.MaxConnLifetime).Add(time.Duration(jitterSecs) * time.Second)
231
232 cr := &connResource{
233 conn: conn,
234 conns: make([]Conn, 64),
235 poolRows: make([]poolRow, 64),
236 poolRowss: make([]poolRows, 64),
237 maxAgeTime: maxAgeTime,
238 }
239
240 return cr, nil
241 },
242 Destructor: func(value *connResource) {
243 ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
244 conn := value.conn
245 if p.beforeClose != nil {
246 p.beforeClose(conn)
247 }
248 conn.Close(ctx)
249 select {
250 case <-conn.PgConn().CleanupDone():
251 case <-ctx.Done():
252 }
253 cancel()
254 },
255 MaxSize: config.MaxConns,
256 },
257 )
258 if err != nil {
259 return nil, err
260 }
261
262 go func() {
263 p.createIdleResources(ctx, int(p.minConns))
264 p.backgroundHealthCheck()
265 }()
266
267 return p, nil
268 }
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287 func ParseConfig(connString string) (*Config, error) {
288 connConfig, err := pgx.ParseConfig(connString)
289 if err != nil {
290 return nil, err
291 }
292
293 config := &Config{
294 ConnConfig: connConfig,
295 createdByParseConfig: true,
296 }
297
298 if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conns"]; ok {
299 delete(connConfig.Config.RuntimeParams, "pool_max_conns")
300 n, err := strconv.ParseInt(s, 10, 32)
301 if err != nil {
302 return nil, fmt.Errorf("cannot parse pool_max_conns: %w", err)
303 }
304 if n < 1 {
305 return nil, fmt.Errorf("pool_max_conns too small: %d", n)
306 }
307 config.MaxConns = int32(n)
308 } else {
309 config.MaxConns = defaultMaxConns
310 if numCPU := int32(runtime.NumCPU()); numCPU > config.MaxConns {
311 config.MaxConns = numCPU
312 }
313 }
314
315 if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_conns"]; ok {
316 delete(connConfig.Config.RuntimeParams, "pool_min_conns")
317 n, err := strconv.ParseInt(s, 10, 32)
318 if err != nil {
319 return nil, fmt.Errorf("cannot parse pool_min_conns: %w", err)
320 }
321 config.MinConns = int32(n)
322 } else {
323 config.MinConns = defaultMinConns
324 }
325
326 if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok {
327 delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
328 d, err := time.ParseDuration(s)
329 if err != nil {
330 return nil, fmt.Errorf("invalid pool_max_conn_lifetime: %w", err)
331 }
332 config.MaxConnLifetime = d
333 } else {
334 config.MaxConnLifetime = defaultMaxConnLifetime
335 }
336
337 if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_idle_time"]; ok {
338 delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time")
339 d, err := time.ParseDuration(s)
340 if err != nil {
341 return nil, fmt.Errorf("invalid pool_max_conn_idle_time: %w", err)
342 }
343 config.MaxConnIdleTime = d
344 } else {
345 config.MaxConnIdleTime = defaultMaxConnIdleTime
346 }
347
348 if s, ok := config.ConnConfig.Config.RuntimeParams["pool_health_check_period"]; ok {
349 delete(connConfig.Config.RuntimeParams, "pool_health_check_period")
350 d, err := time.ParseDuration(s)
351 if err != nil {
352 return nil, fmt.Errorf("invalid pool_health_check_period: %w", err)
353 }
354 config.HealthCheckPeriod = d
355 } else {
356 config.HealthCheckPeriod = defaultHealthCheckPeriod
357 }
358
359 if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime_jitter"]; ok {
360 delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter")
361 d, err := time.ParseDuration(s)
362 if err != nil {
363 return nil, fmt.Errorf("invalid pool_max_conn_lifetime_jitter: %w", err)
364 }
365 config.MaxConnLifetimeJitter = d
366 }
367
368 return config, nil
369 }
370
371
372
373 func (p *Pool) Close() {
374 p.closeOnce.Do(func() {
375 close(p.closeChan)
376 p.p.Close()
377 })
378 }
379
380 func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool {
381 return time.Now().After(res.Value().maxAgeTime)
382 }
383
384 func (p *Pool) triggerHealthCheck() {
385 go func() {
386
387
388 time.Sleep(500 * time.Millisecond)
389 select {
390 case p.healthCheckChan <- struct{}{}:
391 default:
392 }
393 }()
394 }
395
396 func (p *Pool) backgroundHealthCheck() {
397 ticker := time.NewTicker(p.healthCheckPeriod)
398 defer ticker.Stop()
399 for {
400 select {
401 case <-p.closeChan:
402 return
403 case <-p.healthCheckChan:
404 p.checkHealth()
405 case <-ticker.C:
406 p.checkHealth()
407 }
408 }
409 }
410
411 func (p *Pool) checkHealth() {
412 for {
413
414
415 if err := p.checkMinConns(); err != nil {
416
417 break
418 }
419 if !p.checkConnsHealth() {
420
421 break
422 }
423
424
425 select {
426 case <-p.closeChan:
427 return
428 case <-time.After(500 * time.Millisecond):
429 }
430 }
431 }
432
433
434
435 func (p *Pool) checkConnsHealth() bool {
436 var destroyed bool
437 totalConns := p.Stat().TotalConns()
438 resources := p.p.AcquireAllIdle()
439 for _, res := range resources {
440
441 if p.isExpired(res) && totalConns >= p.minConns {
442 atomic.AddInt64(&p.lifetimeDestroyCount, 1)
443 res.Destroy()
444 destroyed = true
445
446 totalConns--
447 } else if res.IdleDuration() > p.maxConnIdleTime && totalConns > p.minConns {
448 atomic.AddInt64(&p.idleDestroyCount, 1)
449 res.Destroy()
450 destroyed = true
451
452 totalConns--
453 } else {
454 res.ReleaseUnused()
455 }
456 }
457 return destroyed
458 }
459
460 func (p *Pool) checkMinConns() error {
461
462
463
464 toCreate := p.minConns - p.Stat().TotalConns()
465 if toCreate > 0 {
466 return p.createIdleResources(context.Background(), int(toCreate))
467 }
468 return nil
469 }
470
471 func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error {
472 ctx, cancel := context.WithCancel(parentCtx)
473 defer cancel()
474
475 errs := make(chan error, targetResources)
476
477 for i := 0; i < targetResources; i++ {
478 go func() {
479 err := p.p.CreateResource(ctx)
480
481 if err == puddle.ErrNotAvailable {
482 err = nil
483 }
484 errs <- err
485 }()
486 }
487
488 var firstError error
489 for i := 0; i < targetResources; i++ {
490 err := <-errs
491 if err != nil && firstError == nil {
492 cancel()
493 firstError = err
494 }
495 }
496
497 return firstError
498 }
499
500
501 func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
502 for {
503 res, err := p.p.Acquire(ctx)
504 if err != nil {
505 return nil, err
506 }
507
508 cr := res.Value()
509
510 if res.IdleDuration() > time.Second {
511 err := cr.conn.Ping(ctx)
512 if err != nil {
513 res.Destroy()
514 continue
515 }
516 }
517
518 if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
519 return cr.getConn(p, res), nil
520 }
521
522 res.Destroy()
523 }
524 }
525
526
527
528
529 func (p *Pool) AcquireFunc(ctx context.Context, f func(*Conn) error) error {
530 conn, err := p.Acquire(ctx)
531 if err != nil {
532 return err
533 }
534 defer conn.Release()
535
536 return f(conn)
537 }
538
539
540
541 func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn {
542 resources := p.p.AcquireAllIdle()
543 conns := make([]*Conn, 0, len(resources))
544 for _, res := range resources {
545 cr := res.Value()
546 if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
547 conns = append(conns, cr.getConn(p, res))
548 } else {
549 res.Destroy()
550 }
551 }
552
553 return conns
554 }
555
556
557
558
559
560
561 func (p *Pool) Reset() {
562 p.p.Reset()
563 }
564
565
566 func (p *Pool) Config() *Config { return p.config.Copy() }
567
568
569 func (p *Pool) Stat() *Stat {
570 return &Stat{
571 s: p.p.Stat(),
572 newConnsCount: atomic.LoadInt64(&p.newConnsCount),
573 lifetimeDestroyCount: atomic.LoadInt64(&p.lifetimeDestroyCount),
574 idleDestroyCount: atomic.LoadInt64(&p.idleDestroyCount),
575 }
576 }
577
578
579
580
581
582 func (p *Pool) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
583 c, err := p.Acquire(ctx)
584 if err != nil {
585 return pgconn.CommandTag{}, err
586 }
587 defer c.Release()
588
589 return c.Exec(ctx, sql, arguments...)
590 }
591
592
593
594
595
596
597
598
599
600
601
602 func (p *Pool) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
603 c, err := p.Acquire(ctx)
604 if err != nil {
605 return errRows{err: err}, err
606 }
607
608 rows, err := c.Query(ctx, sql, args...)
609 if err != nil {
610 c.Release()
611 return errRows{err: err}, err
612 }
613
614 return c.getPoolRows(rows), nil
615 }
616
617
618
619
620
621
622
623
624
625
626
627
628
629 func (p *Pool) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
630 c, err := p.Acquire(ctx)
631 if err != nil {
632 return errRow{err: err}
633 }
634
635 row := c.QueryRow(ctx, sql, args...)
636 return c.getPoolRow(row)
637 }
638
639 func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
640 c, err := p.Acquire(ctx)
641 if err != nil {
642 return errBatchResults{err: err}
643 }
644
645 br := c.SendBatch(ctx, b)
646 return &poolBatchResults{br: br, c: c}
647 }
648
649
650
651
652
653 func (p *Pool) Begin(ctx context.Context) (pgx.Tx, error) {
654 return p.BeginTx(ctx, pgx.TxOptions{})
655 }
656
657
658
659
660
661 func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
662 c, err := p.Acquire(ctx)
663 if err != nil {
664 return nil, err
665 }
666
667 t, err := c.BeginTx(ctx, txOptions)
668 if err != nil {
669 c.Release()
670 return nil, err
671 }
672
673 return &Tx{t: t, c: c}, nil
674 }
675
676 func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
677 c, err := p.Acquire(ctx)
678 if err != nil {
679 return 0, err
680 }
681 defer c.Release()
682
683 return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc)
684 }
685
686
687
688 func (p *Pool) Ping(ctx context.Context) error {
689 c, err := p.Acquire(ctx)
690 if err != nil {
691 return err
692 }
693 defer c.Release()
694 return c.Ping(ctx)
695 }
696
View as plain text