1 package pq
2
3 import (
4 "bytes"
5 "context"
6 "database/sql/driver"
7 "encoding/binary"
8 "errors"
9 "fmt"
10 "sync"
11 )
12
// Sentinel errors reported by the COPY FROM STDIN support in this file.
var (
	// errCopyNotSupportedOutsideTxn is returned by prepareCopyIn when the
	// connection is not currently inside a transaction.
	errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")

	// errBinaryCopyNotSupported is used when the server announces a copy
	// whose overall format is not text.
	errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")

	// errCopyToNotSupported is used when the server answers the COPY query
	// with a copy-out response instead of a copy-in response.
	errCopyToNotSupported = errors.New("pq: COPY TO is not supported")

	// errCopyInClosed is returned by Exec and CopyData once the statement
	// has been closed.
	errCopyInClosed = errors.New("pq: copyin statement has already been closed")

	// errCopyInProgress is not referenced in this file.
	// NOTE(review): presumably used elsewhere in the package while a COPY is
	// active — confirm.
	errCopyInProgress = errors.New("pq: COPY in progress")
)
20
21
22
23 func CopyIn(table string, columns ...string) string {
24 buffer := bytes.NewBufferString("COPY ")
25 BufferQuoteIdentifier(table, buffer)
26 buffer.WriteString(" (")
27 makeStmt(buffer, columns...)
28 return buffer.String()
29 }
30
31
32 func makeStmt(buffer *bytes.Buffer, columns ...string) {
33
34 for i, col := range columns {
35 if i != 0 {
36 buffer.WriteString(", ")
37 }
38 BufferQuoteIdentifier(col, buffer)
39 }
40 buffer.WriteString(") FROM STDIN")
41 }
42
43
44
45 func CopyInSchema(schema, table string, columns ...string) string {
46 buffer := bytes.NewBufferString("COPY ")
47 BufferQuoteIdentifier(schema, buffer)
48 buffer.WriteRune('.')
49 BufferQuoteIdentifier(table, buffer)
50 buffer.WriteString(" (")
51 makeStmt(buffer, columns...)
52 return buffer.String()
53 }
54
// copyin is the driver.Stmt that prepareCopyIn returns for a COPY ... FROM
// STDIN statement. Row data passed to Exec or CopyData is accumulated in
// buffer and written to the connection as 'd' (CopyData) messages. A
// background goroutine (resploop) reads the server's responses and records
// the result and any error under mu.
type copyin struct {
	// cn is the connection the COPY runs on.
	cn *conn
	// buffer holds the outgoing message: a 'd' type byte, a 4-byte length
	// placeholder filled in by flush, then the row data appended so far.
	buffer []byte
	// NOTE(review): rowData is created in prepareCopyIn but not read or
	// written anywhere else in this file — possibly unused.
	rowData chan []byte
	// done receives a value from resploop when it stops reading responses.
	done chan bool

	// closed is set by Close; later Exec/CopyData calls return errCopyInClosed.
	closed bool

	// mu guards the error and result recorded by resploop, which runs in a
	// separate goroutine from the callers of Exec/CopyData/Close.
	mu struct {
		sync.Mutex
		err error
		driver.Result
	}
}
69
const (
	// ciBufferSize is the initial capacity of a copyin's outgoing buffer.
	ciBufferSize = 64 << 10

	// ciBufferFlushSize is the fill level past which Exec and CopyData write
	// the buffer out. It sits below ciBufferSize, presumably so a typical row
	// can be appended without forcing the buffer to reallocate.
	ciBufferFlushSize = 63 << 10
)
74
75 func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
76 if !cn.isInTransaction() {
77 return nil, errCopyNotSupportedOutsideTxn
78 }
79
80 ci := ©in{
81 cn: cn,
82 buffer: make([]byte, 0, ciBufferSize),
83 rowData: make(chan []byte),
84 done: make(chan bool, 1),
85 }
86
87 ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
88
89 b := cn.writeBuf('Q')
90 b.string(q)
91 cn.send(b)
92
93 awaitCopyInResponse:
94 for {
95 t, r := cn.recv1()
96 switch t {
97 case 'G':
98 if r.byte() != 0 {
99 err = errBinaryCopyNotSupported
100 break awaitCopyInResponse
101 }
102 go ci.resploop()
103 return ci, nil
104 case 'H':
105 err = errCopyToNotSupported
106 break awaitCopyInResponse
107 case 'E':
108 err = parseError(r)
109 case 'Z':
110 if err == nil {
111 ci.setBad(driver.ErrBadConn)
112 errorf("unexpected ReadyForQuery in response to COPY")
113 }
114 cn.processReadyForQuery(r)
115 return nil, err
116 default:
117 ci.setBad(driver.ErrBadConn)
118 errorf("unknown response for copy query: %q", t)
119 }
120 }
121
122
123 b = cn.writeBuf('f')
124 b.string(err.Error())
125 cn.send(b)
126
127 for {
128 t, r := cn.recv1()
129 switch t {
130 case 'c', 'C', 'E':
131 case 'Z':
132
133 cn.processReadyForQuery(r)
134 return nil, err
135 default:
136 ci.setBad(driver.ErrBadConn)
137 errorf("unknown response for CopyFail: %q", t)
138 }
139 }
140 }
141
142 func (ci *copyin) flush(buf []byte) {
143
144 binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
145
146 _, err := ci.cn.c.Write(buf)
147 if err != nil {
148 panic(err)
149 }
150 }
151
152 func (ci *copyin) resploop() {
153 for {
154 var r readBuf
155 t, err := ci.cn.recvMessage(&r)
156 if err != nil {
157 ci.setBad(driver.ErrBadConn)
158 ci.setError(err)
159 ci.done <- true
160 return
161 }
162 switch t {
163 case 'C':
164
165 res, _ := ci.cn.parseComplete(r.string())
166 ci.setResult(res)
167 case 'N':
168 if n := ci.cn.noticeHandler; n != nil {
169 n(parseError(&r))
170 }
171 case 'Z':
172 ci.cn.processReadyForQuery(&r)
173 ci.done <- true
174 return
175 case 'E':
176 err := parseError(&r)
177 ci.setError(err)
178 default:
179 ci.setBad(driver.ErrBadConn)
180 ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
181 ci.done <- true
182 return
183 }
184 }
185 }
186
187 func (ci *copyin) setBad(err error) {
188 ci.cn.err.set(err)
189 }
190
191 func (ci *copyin) getBad() error {
192 return ci.cn.err.get()
193 }
194
195 func (ci *copyin) err() error {
196 ci.mu.Lock()
197 err := ci.mu.err
198 ci.mu.Unlock()
199 return err
200 }
201
202
203
204 func (ci *copyin) setError(err error) {
205 ci.mu.Lock()
206 if ci.mu.err == nil {
207 ci.mu.err = err
208 }
209 ci.mu.Unlock()
210 }
211
212 func (ci *copyin) setResult(result driver.Result) {
213 ci.mu.Lock()
214 ci.mu.Result = result
215 ci.mu.Unlock()
216 }
217
218 func (ci *copyin) getResult() driver.Result {
219 ci.mu.Lock()
220 result := ci.mu.Result
221 ci.mu.Unlock()
222 if result == nil {
223 return driver.RowsAffected(0)
224 }
225 return result
226 }
227
228 func (ci *copyin) NumInput() int {
229 return -1
230 }
231
232 func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
233 return nil, ErrNotSupported
234 }
235
236
237
238
239
240
241
242
243 func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
244 if ci.closed {
245 return nil, errCopyInClosed
246 }
247
248 if err := ci.getBad(); err != nil {
249 return nil, err
250 }
251 defer ci.cn.errRecover(&err)
252
253 if err := ci.err(); err != nil {
254 return nil, err
255 }
256
257 if len(v) == 0 {
258 if err := ci.Close(); err != nil {
259 return driver.RowsAffected(0), err
260 }
261
262 return ci.getResult(), nil
263 }
264
265 numValues := len(v)
266 for i, value := range v {
267 ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
268 if i < numValues-1 {
269 ci.buffer = append(ci.buffer, '\t')
270 }
271 }
272
273 ci.buffer = append(ci.buffer, '\n')
274
275 if len(ci.buffer) > ciBufferFlushSize {
276 ci.flush(ci.buffer)
277
278 ci.buffer = ci.buffer[:5]
279 }
280
281 return driver.RowsAffected(0), nil
282 }
283
284
285
286
287
288
289
290
291 func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) {
292 if ci.closed {
293 return nil, errCopyInClosed
294 }
295
296 if finish := ci.cn.watchCancel(ctx); finish != nil {
297 defer finish()
298 }
299
300 if err := ci.getBad(); err != nil {
301 return nil, err
302 }
303 defer ci.cn.errRecover(&err)
304
305 if err := ci.err(); err != nil {
306 return nil, err
307 }
308
309 ci.buffer = append(ci.buffer, []byte(line)...)
310 ci.buffer = append(ci.buffer, '\n')
311
312 if len(ci.buffer) > ciBufferFlushSize {
313 ci.flush(ci.buffer)
314
315 ci.buffer = ci.buffer[:5]
316 }
317
318 return driver.RowsAffected(0), nil
319 }
320
321 func (ci *copyin) Close() (err error) {
322 if ci.closed {
323 return nil
324 }
325 ci.closed = true
326
327 if err := ci.getBad(); err != nil {
328 return err
329 }
330 defer ci.cn.errRecover(&err)
331
332 if len(ci.buffer) > 0 {
333 ci.flush(ci.buffer)
334 }
335
336 err = ci.cn.sendSimpleMessage('c')
337 if err != nil {
338 return err
339 }
340
341 <-ci.done
342 ci.cn.inCopy = false
343
344 if err := ci.err(); err != nil {
345 return err
346 }
347 return nil
348 }
349
View as plain text