...
1 package stmtcache
2
import (
	"container/list"
	"context"
	"errors"
	"fmt"
	"sync/atomic"

	"github.com/jackc/pgconn"
)
11
// lruCount counts the LRU caches created in this process. NewLRU increments it
// atomically and uses the new value to build each cache's statement name prefix.
var (
	lruCount uint64
)
13
14
15 type LRU struct {
16 conn *pgconn.PgConn
17 mode int
18 cap int
19 prepareCount int
20 m map[string]*list.Element
21 l *list.List
22 psNamePrefix string
23 stmtsToClear []string
24 }
25
26
27 func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
28 mustBeValidMode(mode)
29 mustBeValidCap(cap)
30
31 n := atomic.AddUint64(&lruCount, 1)
32
33 return &LRU{
34 conn: conn,
35 mode: mode,
36 cap: cap,
37 m: make(map[string]*list.Element),
38 l: list.New(),
39 psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
40 }
41 }
42
43
44 func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
45 if ctx != context.Background() {
46 select {
47 case <-ctx.Done():
48 return nil, ctx.Err()
49 default:
50 }
51 }
52
53
54 txStatus := c.conn.TxStatus()
55 if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
56 for _, stmt := range c.stmtsToClear {
57 err := c.clearStmt(ctx, stmt)
58 if err != nil {
59 return nil, err
60 }
61 }
62 }
63
64 if el, ok := c.m[sql]; ok {
65 c.l.MoveToFront(el)
66 return el.Value.(*pgconn.StatementDescription), nil
67 }
68
69 if c.l.Len() == c.cap {
70 err := c.removeOldest(ctx)
71 if err != nil {
72 return nil, err
73 }
74 }
75
76 psd, err := c.prepare(ctx, sql)
77 if err != nil {
78 return nil, err
79 }
80
81 el := c.l.PushFront(psd)
82 c.m[sql] = el
83
84 return psd, nil
85 }
86
87
88 func (c *LRU) Clear(ctx context.Context) error {
89 for c.l.Len() > 0 {
90 err := c.removeOldest(ctx)
91 if err != nil {
92 return err
93 }
94 }
95
96 return nil
97 }
98
99 func (c *LRU) StatementErrored(sql string, err error) {
100 pgErr, ok := err.(*pgconn.PgError)
101 if !ok {
102 return
103 }
104
105
106
107
108
109
110
111 possibleInvalidCachedPlanError := pgErr.Code == "0A000"
112 if possibleInvalidCachedPlanError {
113 c.stmtsToClear = append(c.stmtsToClear, sql)
114 }
115 }
116
117 func (c *LRU) clearStmt(ctx context.Context, sql string) error {
118 elem, inMap := c.m[sql]
119 if !inMap {
120
121
122 return nil
123 }
124
125 c.l.Remove(elem)
126
127 psd := elem.Value.(*pgconn.StatementDescription)
128 delete(c.m, psd.SQL)
129 if c.mode == ModePrepare {
130 return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
131 }
132 return nil
133 }
134
135
136 func (c *LRU) Len() int {
137 return c.l.Len()
138 }
139
140
141 func (c *LRU) Cap() int {
142 return c.cap
143 }
144
145
146 func (c *LRU) Mode() int {
147 return c.mode
148 }
149
150 func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
151 var name string
152 if c.mode == ModePrepare {
153 name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
154 c.prepareCount += 1
155 }
156
157 return c.conn.Prepare(ctx, name, sql, nil)
158 }
159
160 func (c *LRU) removeOldest(ctx context.Context) error {
161 oldest := c.l.Back()
162 c.l.Remove(oldest)
163 psd := oldest.Value.(*pgconn.StatementDescription)
164 delete(c.m, psd.SQL)
165 if c.mode == ModePrepare {
166 return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
167 }
168 return nil
169 }
170
View as plain text