1 package pgx
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8 "time"
9
10 "github.com/jackc/pgconn"
11 "github.com/jackc/pgio"
12 )
13
14
15
16 func CopyFromRows(rows [][]interface{}) CopyFromSource {
17 return ©FromRows{rows: rows, idx: -1}
18 }
19
// copyFromRows is the CopyFromSource built by CopyFromRows. It walks an
// in-memory slice of rows; CopyFromRows initializes idx to -1 so that the
// first Next call selects row 0.
type copyFromRows struct {
	rows [][]interface{} // rows to yield, in order
	idx  int             // index of the current row within rows
}

// Next advances to the following row and reports whether such a row exists.
func (r *copyFromRows) Next() bool {
	r.idx++
	return len(r.rows) > r.idx
}

// Values returns the current row and never reports an error. Calling it when
// the preceding Next did not return true indexes outside rows and panics.
func (r *copyFromRows) Values() ([]interface{}, error) {
	row := r.rows[r.idx]
	return row, nil
}

// Err always returns nil: iterating an in-memory slice cannot fail.
func (r *copyFromRows) Err() error { return nil }
37
38
39
40 func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource {
41 return ©FromSlice{next: next, idx: -1, len: length}
42 }
43
// copyFromSlice is the CopyFromSource built by CopyFromSlice. Rows are
// produced lazily by calling next for each index in [0, len).
type copyFromSlice struct {
	next func(int) ([]interface{}, error) // builds the row for an index
	idx  int                              // current index; CopyFromSlice starts it at -1
	len  int                              // number of rows to yield
	err  error                            // most recent non-nil error from next
}

// Next advances to the following index and reports whether it is below len.
func (s *copyFromSlice) Next() bool {
	s.idx++
	return s.len > s.idx
}

// Values builds the current row by calling next. The row returned by next is
// passed through unchanged; a non-nil error is additionally recorded so that
// Err reports it. A later successful call does not clear a recorded error.
func (s *copyFromSlice) Values() ([]interface{}, error) {
	row, err := s.next(s.idx)
	if err == nil {
		return row, nil
	}
	s.err = err
	return row, err
}

// Err returns the most recent non-nil error produced by next, or nil.
func (s *copyFromSlice) Err() error {
	return s.err
}
67
68
// CopyFromSource supplies the rows that CopyFrom sends to the server.
// CopyFrom repeatedly calls Next and, each time it returns true, Values; it
// checks Err after every batch of rows it encodes.
type CopyFromSource interface {
	// Next advances to the next row and reports whether one is available.
	// Iteration ends the first time it returns false.
	Next() bool

	// Values returns the current row, one value per destination column in
	// column order. A non-nil error aborts the copy.
	Values() ([]interface{}, error)

	// Err reports any error the source has encountered. A non-nil result
	// aborts the copy.
	Err() error
}
82
83 type copyFrom struct {
84 conn *Conn
85 tableName Identifier
86 columnNames []string
87 rowSrc CopyFromSource
88 readerErrChan chan error
89 }
90
// run performs the COPY described by ct.
//
// It first prepares "select <columns> from <table>" as the unnamed statement.
// This is only done to obtain each destination column's data type OID for
// buildCopyBuf. It then streams the rows from ct.rowSrc to the server as
// binary COPY data.
//
// Encoding happens on a separate goroutine that writes into an io.Pipe. The
// pipe's read end is handed to pgConn.CopyFrom, and run does not return until
// that goroutine has exited.
//
// The result is the row count from the server's command tag. The outcome is
// logged at info level (success) or error level (failure) when the
// connection's log level allows it.
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
	quotedTableName := ct.tableName.Sanitize()

	var colList bytes.Buffer
	for i, name := range ct.columnNames {
		if i > 0 {
			colList.WriteString(", ")
		}
		colList.WriteString(quoteIdentifier(name))
	}
	quotedColumnNames := colList.String()

	// Only the statement description is used (column OIDs for encoding);
	// the statement itself is not executed here.
	sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
	if err != nil {
		return 0, err
	}

	pr, pw := io.Pipe()
	writerDone := make(chan struct{})

	go func() {
		defer close(writerDone)

		// pw.Close is deliberately not deferred: every exit path below closes
		// the pipe exactly once, using CloseWithError when the copy must fail.
		buf := ct.conn.wbuf

		// Binary COPY file header: the 11-byte signature, then a 32-bit flags
		// field and a 32-bit header-extension length, both zero.
		buf = append(buf, "PGCOPY\n\377\r\n\000"...)
		buf = pgio.AppendInt32(buf, 0)
		buf = pgio.AppendInt32(buf, 0)

		for more := true; more; buf = buf[:0] {
			var err error
			if more, buf, err = ct.buildCopyBuf(buf, sd); err != nil {
				pw.CloseWithError(err)
				return
			}
			if srcErr := ct.rowSrc.Err(); srcErr != nil {
				pw.CloseWithError(srcErr)
				return
			}
			if len(buf) == 0 {
				continue
			}
			if _, err = pw.Write(buf); err != nil {
				// A pipe write only fails once the read end is closed, so
				// nobody is left to receive data or an error.
				pw.Close()
				return
			}
		}

		pw.Close()
	}()

	startTime := time.Now()
	commandTag, err := ct.conn.pgConn.CopyFrom(ctx, pr, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))

	// Closing the read end makes any pending or later pipe write fail, so the
	// writer goroutine can finish; then wait for it to exit.
	pr.Close()
	<-writerDone

	rowsAffected := commandTag.RowsAffected()
	elapsed := time.Since(startTime)

	switch {
	case err == nil && ct.conn.shouldLog(LogLevelInfo):
		ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": elapsed, "rowCount": rowsAffected})
	case err != nil && ct.conn.shouldLog(LogLevelError):
		ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": elapsed})
	}

	return rowsAffected, err
}
167
168 func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
169
170 for ct.rowSrc.Next() {
171 values, err := ct.rowSrc.Values()
172 if err != nil {
173 return false, nil, err
174 }
175 if len(values) != len(ct.columnNames) {
176 return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
177 }
178
179 buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
180 for i, val := range values {
181 buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val)
182 if err != nil {
183 return false, nil, err
184 }
185 }
186
187 if len(buf) > 65536 {
188 return true, buf, nil
189 }
190 }
191
192 return false, buf, nil
193 }
194
195
196
197
198
199
200
201 func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
202 ct := ©From{
203 conn: c,
204 tableName: tableName,
205 columnNames: columnNames,
206 rowSrc: rowSrc,
207 readerErrChan: make(chan error),
208 }
209
210 return ct.run(ctx)
211 }
212
View as plain text