...

Source file src/github.com/jackc/pgx/v4/copy_from.go

Documentation: github.com/jackc/pgx/v4

     1  package pgx
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"time"
     9  
    10  	"github.com/jackc/pgconn"
    11  	"github.com/jackc/pgio"
    12  )
    13  
    14  // CopyFromRows returns a CopyFromSource interface over the provided rows slice
    15  // making it usable by *Conn.CopyFrom.
    16  func CopyFromRows(rows [][]interface{}) CopyFromSource {
    17  	return &copyFromRows{rows: rows, idx: -1}
    18  }
    19  
    20  type copyFromRows struct {
    21  	rows [][]interface{}
    22  	idx  int
    23  }
    24  
    25  func (ctr *copyFromRows) Next() bool {
    26  	ctr.idx++
    27  	return ctr.idx < len(ctr.rows)
    28  }
    29  
    30  func (ctr *copyFromRows) Values() ([]interface{}, error) {
    31  	return ctr.rows[ctr.idx], nil
    32  }
    33  
    34  func (ctr *copyFromRows) Err() error {
    35  	return nil
    36  }
    37  
    38  // CopyFromSlice returns a CopyFromSource interface over a dynamic func
    39  // making it usable by *Conn.CopyFrom.
    40  func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource {
    41  	return &copyFromSlice{next: next, idx: -1, len: length}
    42  }
    43  
    44  type copyFromSlice struct {
    45  	next func(int) ([]interface{}, error)
    46  	idx  int
    47  	len  int
    48  	err  error
    49  }
    50  
    51  func (cts *copyFromSlice) Next() bool {
    52  	cts.idx++
    53  	return cts.idx < cts.len
    54  }
    55  
    56  func (cts *copyFromSlice) Values() ([]interface{}, error) {
    57  	values, err := cts.next(cts.idx)
    58  	if err != nil {
    59  		cts.err = err
    60  	}
    61  	return values, err
    62  }
    63  
    64  func (cts *copyFromSlice) Err() error {
    65  	return cts.err
    66  }
    67  
    68  // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
    69  type CopyFromSource interface {
    70  	// Next returns true if there is another row and makes the next row data
    71  	// available to Values(). When there are no more rows available or an error
    72  	// has occurred it returns false.
    73  	Next() bool
    74  
    75  	// Values returns the values for the current row.
    76  	Values() ([]interface{}, error)
    77  
    78  	// Err returns any error that has been encountered by the CopyFromSource. If
    79  	// this is not nil *Conn.CopyFrom will abort the copy.
    80  	Err() error
    81  }
    82  
    83  type copyFrom struct {
    84  	conn          *Conn
    85  	tableName     Identifier
    86  	columnNames   []string
    87  	rowSrc        CopyFromSource
    88  	readerErrChan chan error
    89  }
    90  
    91  func (ct *copyFrom) run(ctx context.Context) (int64, error) {
    92  	quotedTableName := ct.tableName.Sanitize()
    93  	cbuf := &bytes.Buffer{}
    94  	for i, cn := range ct.columnNames {
    95  		if i != 0 {
    96  			cbuf.WriteString(", ")
    97  		}
    98  		cbuf.WriteString(quoteIdentifier(cn))
    99  	}
   100  	quotedColumnNames := cbuf.String()
   101  
   102  	sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
   103  	if err != nil {
   104  		return 0, err
   105  	}
   106  
   107  	r, w := io.Pipe()
   108  	doneChan := make(chan struct{})
   109  
   110  	go func() {
   111  		defer close(doneChan)
   112  
   113  		// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
   114  		buf := ct.conn.wbuf
   115  
   116  		buf = append(buf, "PGCOPY\n\377\r\n\000"...)
   117  		buf = pgio.AppendInt32(buf, 0)
   118  		buf = pgio.AppendInt32(buf, 0)
   119  
   120  		moreRows := true
   121  		for moreRows {
   122  			var err error
   123  			moreRows, buf, err = ct.buildCopyBuf(buf, sd)
   124  			if err != nil {
   125  				w.CloseWithError(err)
   126  				return
   127  			}
   128  
   129  			if ct.rowSrc.Err() != nil {
   130  				w.CloseWithError(ct.rowSrc.Err())
   131  				return
   132  			}
   133  
   134  			if len(buf) > 0 {
   135  				_, err = w.Write(buf)
   136  				if err != nil {
   137  					w.Close()
   138  					return
   139  				}
   140  			}
   141  
   142  			buf = buf[:0]
   143  		}
   144  
   145  		w.Close()
   146  	}()
   147  
   148  	startTime := time.Now()
   149  
   150  	commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
   151  
   152  	r.Close()
   153  	<-doneChan
   154  
   155  	rowsAffected := commandTag.RowsAffected()
   156  	endTime := time.Now()
   157  	if err == nil {
   158  		if ct.conn.shouldLog(LogLevelInfo) {
   159  			ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected})
   160  		}
   161  	} else if ct.conn.shouldLog(LogLevelError) {
   162  		ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)})
   163  	}
   164  
   165  	return rowsAffected, err
   166  }
   167  
   168  func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
   169  
   170  	for ct.rowSrc.Next() {
   171  		values, err := ct.rowSrc.Values()
   172  		if err != nil {
   173  			return false, nil, err
   174  		}
   175  		if len(values) != len(ct.columnNames) {
   176  			return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
   177  		}
   178  
   179  		buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
   180  		for i, val := range values {
   181  			buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val)
   182  			if err != nil {
   183  				return false, nil, err
   184  			}
   185  		}
   186  
   187  		if len(buf) > 65536 {
   188  			return true, buf, nil
   189  		}
   190  	}
   191  
   192  	return false, buf, nil
   193  }
   194  
   195  // CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
   196  // It returns the number of rows copied and an error.
   197  //
   198  // CopyFrom requires all values use the binary format. Almost all types
   199  // implemented by pgx use the binary format by default. Types implementing
   200  // Encoder can only be used if they encode to the binary format.
   201  func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
   202  	ct := &copyFrom{
   203  		conn:          c,
   204  		tableName:     tableName,
   205  		columnNames:   columnNames,
   206  		rowSrc:        rowSrc,
   207  		readerErrChan: make(chan error),
   208  	}
   209  
   210  	return ct.run(ctx)
   211  }
   212  

View as plain text