1 package pgx
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8
9 "github.com/jackc/pgx/v5/internal/pgio"
10 "github.com/jackc/pgx/v5/pgconn"
11 )
12
13
14
// CopyFromRows returns a CopyFromSource over the provided rows slice, making it
// usable by *Conn.CopyFrom.
func CopyFromRows(rows [][]any) CopyFromSource {
	return &copyFromRows{rows: rows, idx: -1}
}

// copyFromRows iterates an in-memory [][]any. idx starts at -1 so that the
// first call to Next advances to row 0.
type copyFromRows struct {
	rows [][]any
	idx  int
}

// Next advances to the next row and reports whether one exists.
func (ctr *copyFromRows) Next() bool {
	ctr.idx++
	return ctr.idx < len(ctr.rows)
}

// Values returns the current row. It must only be called after Next returned true.
func (ctr *copyFromRows) Values() ([]any, error) {
	return ctr.rows[ctr.idx], nil
}

// Err always returns nil; an in-memory slice cannot fail.
func (ctr *copyFromRows) Err() error {
	return nil
}

// CopyFromSlice returns a CopyFromSource over length rows, where next(i)
// produces the values for row i (0 <= i < length), making it usable by
// *Conn.CopyFrom. An error returned by next is remembered and reported by Err.
func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource {
	return &copyFromSlice{next: next, idx: -1, len: length}
}

// copyFromSlice lazily produces rows by index via next. idx starts at -1 so the
// first call to Next advances to index 0.
type copyFromSlice struct {
	next func(int) ([]any, error)
	idx  int
	len  int
	err  error
}

// Next advances to the next index and reports whether it is below len.
func (cts *copyFromSlice) Next() bool {
	cts.idx++
	return cts.idx < cts.len
}

// Values calls next for the current index, recording any error for Err.
func (cts *copyFromSlice) Values() ([]any, error) {
	values, err := cts.next(cts.idx)
	if err != nil {
		cts.err = err
	}
	return values, err
}

// Err returns the last error produced by next, if any.
func (cts *copyFromSlice) Err() error {
	return cts.err
}

// CopyFromFunc returns a CopyFromSource that relies on nxtf for values, making
// it usable by *Conn.CopyFrom. nxtf returns rows until it either signals the end
// of data by returning row == nil and err == nil, or it returns an error. When
// nxtf returns an error, Err reports it and the copy is aborted.
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
	return &copyFromFunc{next: nxtf}
}

// copyFromFunc pulls rows from next; Next fetches the row eagerly so Values and
// Err can report what that call produced.
type copyFromFunc struct {
	next     func() ([]any, error)
	valueRow []any
	err      error
}

// Next fetches the next row and reports true only if a non-nil row was
// returned without an error.
func (g *copyFromFunc) Next() bool {
	g.valueRow, g.err = g.next()
	return g.valueRow != nil && g.err == nil
}

// Values returns the row (and error) fetched by the most recent Next call.
func (g *copyFromFunc) Values() ([]any, error) {
	return g.valueRow, g.err
}

// Err returns the error from the most recent call to next, if any.
func (g *copyFromFunc) Err() error {
	return g.err
}

// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy
// data.
type CopyFromSource interface {
	// Next returns true if there is another row and makes the next row data
	// available to Values. When there are no more rows available or an error
	// has occurred it returns false.
	Next() bool

	// Values returns the values for the current row.
	Values() ([]any, error)

	// Err returns any error that has been encountered by the CopyFromSource.
	// If this is not nil, *Conn.CopyFrom aborts the copy.
	Err() error
}
108
109 type copyFrom struct {
110 conn *Conn
111 tableName Identifier
112 columnNames []string
113 rowSrc CopyFromSource
114 readerErrChan chan error
115 mode QueryExecMode
116 }
117
118 func (ct *copyFrom) run(ctx context.Context) (int64, error) {
119 if ct.conn.copyFromTracer != nil {
120 ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
121 TableName: ct.tableName,
122 ColumnNames: ct.columnNames,
123 })
124 }
125
126 quotedTableName := ct.tableName.Sanitize()
127 cbuf := &bytes.Buffer{}
128 for i, cn := range ct.columnNames {
129 if i != 0 {
130 cbuf.WriteString(", ")
131 }
132 cbuf.WriteString(quoteIdentifier(cn))
133 }
134 quotedColumnNames := cbuf.String()
135
136 var sd *pgconn.StatementDescription
137 switch ct.mode {
138 case QueryExecModeExec, QueryExecModeSimpleProtocol:
139
140
141
142
143
144
145 ct.mode = QueryExecModeDescribeExec
146 fallthrough
147 case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
148 var err error
149 sd, err = ct.conn.getStatementDescription(
150 ctx,
151 ct.mode,
152 fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
153 )
154 if err != nil {
155 return 0, fmt.Errorf("statement description failed: %w", err)
156 }
157 default:
158 return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
159 }
160
161 r, w := io.Pipe()
162 doneChan := make(chan struct{})
163
164 go func() {
165 defer close(doneChan)
166
167
168 buf := ct.conn.wbuf
169
170 buf = append(buf, "PGCOPY\n\377\r\n\000"...)
171 buf = pgio.AppendInt32(buf, 0)
172 buf = pgio.AppendInt32(buf, 0)
173
174 moreRows := true
175 for moreRows {
176 var err error
177 moreRows, buf, err = ct.buildCopyBuf(buf, sd)
178 if err != nil {
179 w.CloseWithError(err)
180 return
181 }
182
183 if ct.rowSrc.Err() != nil {
184 w.CloseWithError(ct.rowSrc.Err())
185 return
186 }
187
188 if len(buf) > 0 {
189 _, err = w.Write(buf)
190 if err != nil {
191 w.Close()
192 return
193 }
194 }
195
196 buf = buf[:0]
197 }
198
199 w.Close()
200 }()
201
202 commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
203
204 r.Close()
205 <-doneChan
206
207 if ct.conn.copyFromTracer != nil {
208 ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
209 CommandTag: commandTag,
210 Err: err,
211 })
212 }
213
214 return commandTag.RowsAffected(), err
215 }
216
217 func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
218 const sendBufSize = 65536 - 5
219 lastBufLen := 0
220 largestRowLen := 0
221
222 for ct.rowSrc.Next() {
223 lastBufLen = len(buf)
224
225 values, err := ct.rowSrc.Values()
226 if err != nil {
227 return false, nil, err
228 }
229 if len(values) != len(ct.columnNames) {
230 return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
231 }
232
233 buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
234 for i, val := range values {
235 buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
236 if err != nil {
237 return false, nil, err
238 }
239 }
240
241 rowLen := len(buf) - lastBufLen
242 if rowLen > largestRowLen {
243 largestRowLen = rowLen
244 }
245
246
247
248
249 if len(buf) > sendBufSize-largestRowLen {
250 return true, buf, nil
251 }
252 }
253
254 return false, buf, nil
255 }
256
257
258
259
260
261
262
263
264
265 func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
266 ct := ©From{
267 conn: c,
268 tableName: tableName,
269 columnNames: columnNames,
270 rowSrc: rowSrc,
271 readerErrChan: make(chan error),
272 mode: c.config.DefaultQueryExecMode,
273 }
274
275 return ct.run(ctx)
276 }
277
View as plain text