1 package pq
2
3 import (
4 "context"
5 "database/sql"
6 "database/sql/driver"
7 "fmt"
8 "io"
9 "io/ioutil"
10 "time"
11 )
12
// watchCancelDialContextTimeout is the timeout of the fresh context that the
// cancel watchers hand to the out-of-band cancel request once the caller's
// context is already done (that context can no longer be used for the dial).
const watchCancelDialContextTimeout = 10 * time.Second
16
17
18 func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
19 list := make([]driver.Value, len(args))
20 for i, nv := range args {
21 list[i] = nv.Value
22 }
23 finish := cn.watchCancel(ctx)
24 r, err := cn.query(query, list)
25 if err != nil {
26 if finish != nil {
27 finish()
28 }
29 return nil, err
30 }
31 r.finish = finish
32 return r, nil
33 }
34
35
36 func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
37 list := make([]driver.Value, len(args))
38 for i, nv := range args {
39 list[i] = nv.Value
40 }
41
42 if finish := cn.watchCancel(ctx); finish != nil {
43 defer finish()
44 }
45
46 return cn.Exec(query, list)
47 }
48
49
50 func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
51 if finish := cn.watchCancel(ctx); finish != nil {
52 defer finish()
53 }
54 return cn.Prepare(query)
55 }
56
57
58 func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
59 var mode string
60
61 switch sql.IsolationLevel(opts.Isolation) {
62 case sql.LevelDefault:
63
64 case sql.LevelReadUncommitted:
65 mode = " ISOLATION LEVEL READ UNCOMMITTED"
66 case sql.LevelReadCommitted:
67 mode = " ISOLATION LEVEL READ COMMITTED"
68 case sql.LevelRepeatableRead:
69 mode = " ISOLATION LEVEL REPEATABLE READ"
70 case sql.LevelSerializable:
71 mode = " ISOLATION LEVEL SERIALIZABLE"
72 default:
73 return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
74 }
75
76 if opts.ReadOnly {
77 mode += " READ ONLY"
78 } else {
79 mode += " READ WRITE"
80 }
81
82 tx, err := cn.begin(mode)
83 if err != nil {
84 return nil, err
85 }
86 cn.txnFinish = cn.watchCancel(ctx)
87 return tx, nil
88 }
89
90 func (cn *conn) Ping(ctx context.Context) error {
91 if finish := cn.watchCancel(ctx); finish != nil {
92 defer finish()
93 }
94 rows, err := cn.simpleQuery(";")
95 if err != nil {
96 return driver.ErrBadConn
97 }
98 rows.Close()
99 return nil
100 }
101
// watchCancel watches ctx for cancellation while an operation runs on cn.
//
// If ctx.Done() is nil (ctx can never be canceled), nothing is started and
// nil is returned, so callers must nil-check the result. Otherwise a watcher
// goroutine is started and a finish func is returned. The caller must invoke
// the finish func exactly once when the operation is over.
//
// The watcher and finish race to store a value in finished, a channel with a
// one-slot buffer. Whichever stores first determines the outcome:
//   - ctx done first: the watcher fills the slot, sets cn.err to ctx.Err(),
//     and sends a best-effort cancel request via cn.cancel. The later finish
//     call finds the slot full, takes it, sets cn.err again, and closes cn.
//   - finish first: finish fills the slot. The watcher then either receives
//     it and exits, or (if done also fired) fails its non-blocking send and
//     returns without touching cn.
func (cn *conn) watchCancel(ctx context.Context) func() {
	if done := ctx.Done(); done != nil {
		// One-slot buffer: the first side to store a value claims the outcome.
		finished := make(chan struct{}, 1)
		go func() {
			select {
			case <-done:
				select {
				case finished <- struct{}{}:
				default:
					// finish already claimed the slot: the operation is over,
					// so there is nothing to cancel and cn is left alone.
					return
				}

				// Record the cancellation on the connection.
				// NOTE(review): presumably this marks cn as unusable so it is
				// not reused; confirm against the readers of cn.err.
				cn.err.set(ctx.Err())

				// ctx is already done and cannot be used for the cancel
				// request's dial. A fresh context bounded by
				// watchCancelDialContextTimeout is used instead.
				ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
				defer cancel()

				// Best effort: the cancel request's error is deliberately ignored.
				_ = cn.cancel(ctxCancel)
			case <-finished:
				// finish ran first; nothing to cancel.
			}
		}()
		return func() {
			select {
			case <-finished:
				// The watcher claimed the slot, so a cancellation was
				// triggered: record ctx.Err() and close the connection.
				cn.err.set(ctx.Err())
				cn.Close()
			case finished <- struct{}{}:
				// Claimed the slot first; the watcher will exit on its own.
			}
		}
	}
	// ctx can never be canceled: nothing to watch.
	return nil
}
141
142 func (cn *conn) cancel(ctx context.Context) error {
143
144
145
146
147 o := make(values)
148 for k, v := range cn.opts {
149 o[k] = v
150 }
151
152 c, err := dial(ctx, cn.dialer, o)
153 if err != nil {
154 return err
155 }
156 defer c.Close()
157
158 {
159 can := conn{
160 c: c,
161 }
162 err = can.ssl(o)
163 if err != nil {
164 return err
165 }
166
167 w := can.writeBuf(0)
168 w.int32(80877102)
169 w.int32(cn.processID)
170 w.int32(cn.secretKey)
171
172 if err := can.sendStartupPacket(w); err != nil {
173 return err
174 }
175 }
176
177
178 {
179 _, err := io.Copy(ioutil.Discard, c)
180 return err
181 }
182 }
183
184
185 func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
186 list := make([]driver.Value, len(args))
187 for i, nv := range args {
188 list[i] = nv.Value
189 }
190 finish := st.watchCancel(ctx)
191 r, err := st.query(list)
192 if err != nil {
193 if finish != nil {
194 finish()
195 }
196 return nil, err
197 }
198 r.finish = finish
199 return r, nil
200 }
201
202
203 func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
204 list := make([]driver.Value, len(args))
205 for i, nv := range args {
206 list[i] = nv.Value
207 }
208
209 if finish := st.watchCancel(ctx); finish != nil {
210 defer finish()
211 }
212
213 return st.Exec(list)
214 }
215
216
// watchCancel watches ctx for cancellation while a statement runs.
//
// If ctx.Done() is nil (ctx can never be canceled), it returns nil, so
// callers must nil-check the result. Otherwise it starts a watcher goroutine
// and returns a finish func that must be called once the operation is over.
//
// finished is unbuffered, so the watcher and finish rendezvous:
//   - finish runs while the watcher is still waiting: the watcher receives
//     finish's send and exits.
//   - ctx is done first: the watcher sends a best-effort cancel request via
//     st.cancel, then blocks sending on finished until finish receives.
//     finish therefore cannot return while the cancel request is in flight.
//
// NOTE(review): unlike conn.watchCancel, this neither sets cn.err nor closes
// the connection after a cancellation. The watcher also stays blocked until
// finish is called. Confirm this asymmetry is intended.
func (st *stmt) watchCancel(ctx context.Context) func() {
	if done := ctx.Done(); done != nil {
		// Unbuffered: each send completes only when the other side receives.
		finished := make(chan struct{})
		go func() {
			select {
			case <-done:
				// ctx is already done, so the cancel request gets its own
				// context bounded by watchCancelDialContextTimeout.
				ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
				defer cancel()

				// Best effort: the cancel request's error is deliberately ignored.
				_ = st.cancel(ctxCancel)
				// Hand control back to finish only after the cancel attempt is done.
				finished <- struct{}{}
			case <-finished:
				// finish ran first; nothing to cancel.
			}
		}()
		return func() {
			select {
			case <-finished:
				// The watcher attempted a cancel request and has now finished.
			case finished <- struct{}{}:
				// The watcher was still waiting and has been told to exit.
			}
		}
	}
	// ctx can never be canceled: nothing to watch.
	return nil
}
244
// cancel sends an out-of-band cancel request for the statement's connection
// by delegating to st.cn.cancel. ctx is passed through unchanged.
func (st *stmt) cancel(ctx context.Context) error {
	return st.cn.cancel(ctx)
}
248
View as plain text