
Source file src/github.com/jackc/pgx/v5/copy_from.go

Documentation: github.com/jackc/pgx/v5

package pgx

import (
	"bytes"
	"context"
	"fmt"
	"io"

	"github.com/jackc/pgx/v5/internal/pgio"
	"github.com/jackc/pgx/v5/pgconn"
)

// CopyFromRows returns a CopyFromSource interface over the provided rows slice,
// making it usable by *Conn.CopyFrom.
func CopyFromRows(rows [][]any) CopyFromSource {
	return &copyFromRows{rows: rows, idx: -1}
}

type copyFromRows struct {
	rows [][]any
	idx  int
}

func (ctr *copyFromRows) Next() bool {
	ctr.idx++
	return ctr.idx < len(ctr.rows)
}

func (ctr *copyFromRows) Values() ([]any, error) {
	return ctr.rows[ctr.idx], nil
}

func (ctr *copyFromRows) Err() error {
	return nil
}

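// Editor's note (illustrative sketch, not part of copy_from.go): typical
// caller-side use of CopyFromRows, assuming `import "github.com/jackc/pgx/v5"`.
// The column values below are hypothetical; all rows must already be in memory.
//
//	rows := [][]any{
//		{"John", "Smith", int32(36)},
//		{"Jane", "Doe", int32(29)},
//	}
//	src := pgx.CopyFromRows(rows) // src is then passed to (*Conn).CopyFrom
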
// CopyFromSlice returns a CopyFromSource interface over a dynamic func, making
// it usable by *Conn.CopyFrom. next is called with each row index from 0 to
// length-1 and must return that row's values.
func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource {
	return &copyFromSlice{next: next, idx: -1, len: length}
}

type copyFromSlice struct {
	next func(int) ([]any, error)
	idx  int
	len  int
	err  error
}

func (cts *copyFromSlice) Next() bool {
	cts.idx++
	return cts.idx < cts.len
}

func (cts *copyFromSlice) Values() ([]any, error) {
	values, err := cts.next(cts.idx)
	if err != nil {
		cts.err = err
	}
	return values, err
}

func (cts *copyFromSlice) Err() error {
	return cts.err
}

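// Editor's note (illustrative sketch, not part of copy_from.go): CopyFromSlice
// avoids building a [][]any up front by converting each element of an existing
// slice on demand. The user type and its fields are hypothetical.
//
//	type user struct {
//		name string
//		age  int32
//	}
//
//	users := []user{{"alice", 30}, {"bob", 41}}
//	src := pgx.CopyFromSlice(len(users), func(i int) ([]any, error) {
//		return []any{users[i].name, users[i].age}, nil
//	})
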
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
// or it returns an error. If nxtf returns an error, the copy is aborted.
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
	return &copyFromFunc{next: nxtf}
}

type copyFromFunc struct {
	next     func() ([]any, error)
	valueRow []any
	err      error
}

func (g *copyFromFunc) Next() bool {
	g.valueRow, g.err = g.next()
	// only return true if valueRow exists and no error
	return g.valueRow != nil && g.err == nil
}

func (g *copyFromFunc) Values() ([]any, error) {
	return g.valueRow, g.err
}

func (g *copyFromFunc) Err() error {
	return g.err
}

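// Editor's note (illustrative sketch, not part of copy_from.go): CopyFromFunc
// suits streaming sources whose length is unknown up front, such as reading
// records from a hypothetical *csv.Reader named csvReader. Returning (nil, nil)
// ends the copy; returning an error aborts it.
//
//	src := pgx.CopyFromFunc(func() ([]any, error) {
//		rec, err := csvReader.Read()
//		if errors.Is(err, io.EOF) {
//			return nil, nil // end of data
//		}
//		if err != nil {
//			return nil, err // aborts the copy
//		}
//		return []any{rec[0], rec[1]}, nil
//	})
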
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
type CopyFromSource interface {
	// Next returns true if there is another row and makes the next row data
	// available to Values(). When there are no more rows available or an error
	// has occurred it returns false.
	Next() bool

	// Values returns the values for the current row.
	Values() ([]any, error)

	// Err returns any error that has been encountered by the CopyFromSource. If
	// this is not nil *Conn.CopyFrom will abort the copy.
	Err() error
}

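// Editor's note (illustrative sketch, not part of copy_from.go): any type with
// Next, Values, and Err methods can act as a source. For example, a small
// decorator that counts rows as *Conn.CopyFrom consumes them from a wrapped
// source:
//
//	type countingSource struct {
//		inner pgx.CopyFromSource
//		n     int64
//	}
//
//	func (s *countingSource) Next() bool {
//		if s.inner.Next() {
//			s.n++
//			return true
//		}
//		return false
//	}
//
//	func (s *countingSource) Values() ([]any, error) { return s.inner.Values() }
//	func (s *countingSource) Err() error             { return s.inner.Err() }
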
type copyFrom struct {
	conn          *Conn
	tableName     Identifier
	columnNames   []string
	rowSrc        CopyFromSource
	readerErrChan chan error
	mode          QueryExecMode
}

func (ct *copyFrom) run(ctx context.Context) (int64, error) {
	if ct.conn.copyFromTracer != nil {
		ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
			TableName:   ct.tableName,
			ColumnNames: ct.columnNames,
		})
	}

	quotedTableName := ct.tableName.Sanitize()
	cbuf := &bytes.Buffer{}
	for i, cn := range ct.columnNames {
		if i != 0 {
			cbuf.WriteString(", ")
		}
		cbuf.WriteString(quoteIdentifier(cn))
	}
	quotedColumnNames := cbuf.String()

	var sd *pgconn.StatementDescription
	switch ct.mode {
	case QueryExecModeExec, QueryExecModeSimpleProtocol:
		// These modes don't support the binary format. Before the inclusion of the
		// QueryExecModes, Conn.Prepare was called on every COPY operation to get
		// the OIDs. These prepared statements were not cached.
		//
		// Since that's the same behavior provided by QueryExecModeDescribeExec,
		// we'll default to that mode.
		ct.mode = QueryExecModeDescribeExec
		fallthrough
	case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
		var err error
		sd, err = ct.conn.getStatementDescription(
			ctx,
			ct.mode,
			fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
		)
		if err != nil {
			return 0, fmt.Errorf("statement description failed: %w", err)
		}
	default:
		return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
	}

	r, w := io.Pipe()
	doneChan := make(chan struct{})

	go func() {
		defer close(doneChan)

		// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
		buf := ct.conn.wbuf

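		// The 19 bytes appended below form the PostgreSQL binary COPY header:
		// the 11-byte signature "PGCOPY\n\377\r\n\0", a 32-bit flags field
		// (0, no OIDs), and the 32-bit length of the header extension area (0).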
		buf = append(buf, "PGCOPY\n\377\r\n\000"...)
		buf = pgio.AppendInt32(buf, 0)
		buf = pgio.AppendInt32(buf, 0)

		moreRows := true
		for moreRows {
			var err error
			moreRows, buf, err = ct.buildCopyBuf(buf, sd)
			if err != nil {
				w.CloseWithError(err)
				return
			}

			if ct.rowSrc.Err() != nil {
				w.CloseWithError(ct.rowSrc.Err())
				return
			}

			if len(buf) > 0 {
				_, err = w.Write(buf)
				if err != nil {
					w.Close()
					return
				}
			}

			buf = buf[:0]
		}

		w.Close()
	}()

	commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))

	r.Close()
	<-doneChan

	if ct.conn.copyFromTracer != nil {
		ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
			CommandTag: commandTag,
			Err:        err,
		})
	}

	return commandTag.RowsAffected(), err
}

func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
	const sendBufSize = 65536 - 5 // The packet has a 5-byte header
	lastBufLen := 0
	largestRowLen := 0

	for ct.rowSrc.Next() {
		lastBufLen = len(buf)

		values, err := ct.rowSrc.Values()
		if err != nil {
			return false, nil, err
		}
		if len(values) != len(ct.columnNames) {
			return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
		}

		buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
		for i, val := range values {
			buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
			if err != nil {
				return false, nil, err
			}
		}

		rowLen := len(buf) - lastBufLen
		if rowLen > largestRowLen {
			largestRowLen = rowLen
		}

		// Try not to overflow the size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature
		// of io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13,
		// 65531, 13, 65531, 13, 65531, 13.
		if len(buf) > sendBufSize-largestRowLen {
			return true, buf, nil
		}
	}

	return false, buf, nil
}

// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and
// an error.
//
// CopyFrom requires that all values use the binary format. A pgtype.Type that supports the binary format must be
// registered for the type of each column. Almost all types implemented by pgx support the binary format.
//
// Even though enum types appear to be strings, they still must be registered to be used with CopyFrom. This can be done
// with Conn.LoadType and pgtype.Map.RegisterType.
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
	ct := &copyFrom{
		conn:          c,
		tableName:     tableName,
		columnNames:   columnNames,
		rowSrc:        rowSrc,
		readerErrChan: make(chan error),
		mode:          c.config.DefaultQueryExecMode,
	}

	return ct.run(ctx)
}
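
Example usage (editor's sketch, not taken from copy_from.go): an end-to-end
caller of Conn.CopyFrom using CopyFromRows. The connection string, the table
"people", and its columns are hypothetical placeholders.

	package main

	import (
		"context"
		"log"

		"github.com/jackc/pgx/v5"
	)

	func main() {
		ctx := context.Background()

		conn, err := pgx.Connect(ctx, "postgres://localhost:5432/mydb")
		if err != nil {
			log.Fatal(err)
		}
		defer conn.Close(ctx)

		// Assumes: create table people (first_name text, last_name text, age int4)
		rows := [][]any{
			{"John", "Smith", int32(36)},
			{"Jane", "Doe", int32(29)},
		}

		copyCount, err := conn.CopyFrom(
			ctx,
			pgx.Identifier{"people"},
			[]string{"first_name", "last_name", "age"},
			pgx.CopyFromRows(rows),
		)
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("copied %d rows", copyCount)
	}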
