...

Source file src/github.com/letsencrypt/boulder/db/multi.go

Documentation: github.com/letsencrypt/boulder/db

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  )
     8  
     9  // MultiInserter makes it easy to construct a
    10  // `INSERT INTO table (...) VALUES ... RETURNING id;`
    11  // query which inserts multiple rows into the same table. It can also execute
    12  // the resulting query.
    13  type MultiInserter struct {
    14  	// These are validated by the constructor as containing only characters
    15  	// that are allowed in an unquoted identifier.
    16  	// https://mariadb.com/kb/en/identifier-names/#unquoted
    17  	table           string
    18  	fields          []string
    19  	returningColumn string
    20  
    21  	values [][]interface{}
    22  }
    23  
    24  // NewMultiInserter creates a new MultiInserter, checking for reasonable table
    25  // name and list of fields. returningColumn is the name of a column to be used
    26  // in a `RETURNING xyz` clause at the end. If it is empty, no `RETURNING xyz`
    27  // clause is used. If returningColumn is present, it must refer to a column
    28  // that can be parsed into an int64.
    29  // Safety: `table`, `fields`, and `returningColumn` must contain only strings
    30  // that are known at compile time. They must not contain user-controlled
    31  // strings.
    32  func NewMultiInserter(table string, fields []string, returningColumn string) (*MultiInserter, error) {
    33  	if len(table) == 0 || len(fields) == 0 {
    34  		return nil, fmt.Errorf("empty table name or fields list")
    35  	}
    36  
    37  	err := validMariaDBUnquotedIdentifier(table)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  	for _, field := range fields {
    42  		err := validMariaDBUnquotedIdentifier(field)
    43  		if err != nil {
    44  			return nil, err
    45  		}
    46  	}
    47  	if returningColumn != "" {
    48  		err := validMariaDBUnquotedIdentifier(returningColumn)
    49  		if err != nil {
    50  			return nil, err
    51  		}
    52  	}
    53  
    54  	return &MultiInserter{
    55  		table:           table,
    56  		fields:          fields,
    57  		returningColumn: returningColumn,
    58  		values:          make([][]interface{}, 0),
    59  	}, nil
    60  }
    61  
    62  // Add registers another row to be included in the Insert query.
    63  func (mi *MultiInserter) Add(row []interface{}) error {
    64  	if len(row) != len(mi.fields) {
    65  		return fmt.Errorf("field count mismatch, got %d, expected %d", len(row), len(mi.fields))
    66  	}
    67  	mi.values = append(mi.values, row)
    68  	return nil
    69  }
    70  
    71  // query returns the formatted query string, and the slice of arguments for
    72  // for borp to use in place of the query's question marks. Currently only
    73  // used by .Insert(), below.
    74  func (mi *MultiInserter) query() (string, []interface{}) {
    75  	var questionsBuf strings.Builder
    76  	var queryArgs []interface{}
    77  	for _, row := range mi.values {
    78  		// Safety: We are interpolating a string that will be used in a SQL
    79  		// query, but we constructed that string in this function and know it
    80  		// consists only of question marks joined with commas.
    81  		fmt.Fprintf(&questionsBuf, "(%s),", QuestionMarks(len(mi.fields)))
    82  		queryArgs = append(queryArgs, row...)
    83  	}
    84  
    85  	questions := strings.TrimRight(questionsBuf.String(), ",")
    86  
    87  	// Safety: we are interpolating `mi.returningColumn` into an SQL query. We
    88  	// know it is a valid unquoted identifier in MariaDB because we verified
    89  	// that in the constructor.
    90  	returning := ""
    91  	if mi.returningColumn != "" {
    92  		returning = fmt.Sprintf(" RETURNING %s", mi.returningColumn)
    93  	}
    94  	// Safety: we are interpolating `mi.table` and `mi.fields` into an SQL
    95  	// query. We know they contain, respectively, a valid unquoted identifier
    96  	// and a slice of valid unquoted identifiers because we verified that in
    97  	// the constructor. We know the query overall has valid syntax because we
    98  	// generate it entirely within this function.
    99  	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s%s", mi.table, strings.Join(mi.fields, ","), questions, returning)
   100  
   101  	return query, queryArgs
   102  }
   103  
   104  // Insert inserts all the collected rows into the database represented by
   105  // `queryer`. If a non-empty returningColumn was provided, then it returns
   106  // the list of values from that column returned by the query.
   107  func (mi *MultiInserter) Insert(ctx context.Context, queryer Queryer) ([]int64, error) {
   108  	query, queryArgs := mi.query()
   109  	rows, err := queryer.QueryContext(ctx, query, queryArgs...)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	ids := make([]int64, 0, len(mi.values))
   115  	if mi.returningColumn != "" {
   116  		for rows.Next() {
   117  			var id int64
   118  			err = rows.Scan(&id)
   119  			if err != nil {
   120  				rows.Close()
   121  				return nil, err
   122  			}
   123  			ids = append(ids, id)
   124  		}
   125  	}
   126  
   127  	// Hack: sometimes in unittests we make a mock Queryer that returns a nil
   128  	// `*sql.Rows`. A nil `*sql.Rows` is not actually valid— calling `Close()`
   129  	// on it will panic— but here we choose to treat it like an empty list,
   130  	// and skip calling `Close()` to avoid the panic.
   131  	if rows != nil {
   132  		err = rows.Close()
   133  		if err != nil {
   134  			return nil, err
   135  		}
   136  	}
   137  
   138  	return ids, nil
   139  }
   140  

View as plain text