...

Source file src/github.com/letsencrypt/boulder/db/gorm.go

Documentation: github.com/letsencrypt/boulder/db

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"reflect"
     8  	"regexp"
     9  	"strings"
    10  )
    11  
    12  // Characters allowed in an unquoted identifier by MariaDB.
    13  // https://mariadb.com/kb/en/identifier-names/#unquoted
    14  var mariaDBUnquotedIdentifierRE = regexp.MustCompile("^[0-9a-zA-Z$_]+$")
    15  
    16  func validMariaDBUnquotedIdentifier(s string) error {
    17  	if !mariaDBUnquotedIdentifierRE.MatchString(s) {
    18  		return fmt.Errorf("invalid MariaDB identifier %q", s)
    19  	}
    20  
    21  	allNumeric := true
    22  	startsNumeric := false
    23  	for i, c := range []byte(s) {
    24  		if c < '0' || c > '9' {
    25  			if startsNumeric && len(s) > i && s[i] == 'e' {
    26  				return fmt.Errorf("MariaDB identifier looks like floating point: %q", s)
    27  			}
    28  			allNumeric = false
    29  			break
    30  		}
    31  		startsNumeric = true
    32  	}
    33  	if allNumeric {
    34  		return fmt.Errorf("MariaDB identifier contains only numerals: %q", s)
    35  	}
    36  	return nil
    37  }
    38  
    39  // NewMappedSelector returns an object which can be used to automagically query
    40  // the provided type-mapped database for rows of the parameterized type.
    41  func NewMappedSelector[T any](executor MappedExecutor) (MappedSelector[T], error) {
    42  	var throwaway T
    43  	t := reflect.TypeOf(throwaway)
    44  
    45  	// We use a very strict mapping of struct fields to table columns here:
    46  	// - The struct must not have any embedded structs, only named fields.
    47  	// - The struct field names must be case-insensitively identical to the
    48  	//   column names (no struct tags necessary).
    49  	// - The struct field names must be case-insensitively unique.
    50  	// - Every field of the struct must correspond to a database column.
    51  	//   - Note that the reverse is not true: it's perfectly okay for there to be
    52  	//     database columns which do not correspond to fields in the struct; those
    53  	//     columns will be ignored.
    54  	// TODO: In the future, when we replace borp's TableMap with our own, this
    55  	// check should be performed at the time the mapping is declared.
    56  	columns := make([]string, 0)
    57  	seen := make(map[string]struct{})
    58  	for i := 0; i < t.NumField(); i++ {
    59  		field := t.Field(i)
    60  		if field.Anonymous {
    61  			return nil, fmt.Errorf("struct contains anonymous embedded struct %q", field.Name)
    62  		}
    63  		column := strings.ToLower(t.Field(i).Name)
    64  		err := validMariaDBUnquotedIdentifier(column)
    65  		if err != nil {
    66  			return nil, fmt.Errorf("struct field maps to unsafe db column name %q", column)
    67  		}
    68  		if _, found := seen[column]; found {
    69  			return nil, fmt.Errorf("struct fields map to duplicate column name %q", column)
    70  		}
    71  		seen[column] = struct{}{}
    72  		columns = append(columns, column)
    73  	}
    74  
    75  	return &mappedSelector[T]{wrapped: executor, columns: columns}, nil
    76  }
    77  
    78  type mappedSelector[T any] struct {
    79  	wrapped MappedExecutor
    80  	columns []string
    81  }
    82  
    83  // QueryContext performs a SELECT on the appropriate table for T. It combines the best
    84  // features of borp, the go stdlib, and generics, using the type parameter of
    85  // the typeSelector object to automatically look up the proper table name and
    86  // columns to select. It returns an iterable which yields fully-populated
    87  // objects of the parameterized type directly. The given clauses MUST be only
    88  // the bits of a sql query from "WHERE ..." onwards; if they contain any of the
    89  // "SELECT ... FROM ..." portion of the query it will result in an error. The
    90  // args take the same kinds of values as borp's SELECT: either one argument per
    91  // positional placeholder, or a map of placeholder names to their arguments
    92  // (see https://pkg.go.dev/github.com/letsencrypt/borp#readme-ad-hoc-sql).
    93  //
    94  // The caller is responsible for calling `Rows.Close()` when they are done with
    95  // the query. The caller is also responsible for ensuring that the clauses
    96  // argument does not contain any user-influenced input.
    97  func (ts mappedSelector[T]) QueryContext(ctx context.Context, clauses string, args ...interface{}) (Rows[T], error) {
    98  	// Look up the table to use based on the type of this TypeSelector.
    99  	var throwaway T
   100  	tableMap, err := ts.wrapped.TableFor(reflect.TypeOf(throwaway), false)
   101  	if err != nil {
   102  		return nil, fmt.Errorf("database model type not mapped to table name: %w", err)
   103  	}
   104  
   105  	return ts.QueryFrom(ctx, tableMap.TableName, clauses, args...)
   106  }
   107  
   108  // QueryFrom is the same as Query, but it additionally takes a table name to
   109  // select from, rather than automatically computing the table name from borp's
   110  // DbMap.
   111  //
   112  // The caller is responsible for calling `Rows.Close()` when they are done with
   113  // the query. The caller is also responsible for ensuring that the clauses
   114  // argument does not contain any user-influenced input.
   115  func (ts mappedSelector[T]) QueryFrom(ctx context.Context, tablename string, clauses string, args ...interface{}) (Rows[T], error) {
   116  	err := validMariaDBUnquotedIdentifier(tablename)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	// Construct the query from the column names, table name, and given clauses.
   122  	// Note that the column names here are in the order given by
   123  	query := fmt.Sprintf(
   124  		"SELECT %s FROM %s %s",
   125  		strings.Join(ts.columns, ", "),
   126  		tablename,
   127  		clauses,
   128  	)
   129  
   130  	r, err := ts.wrapped.QueryContext(ctx, query, args...)
   131  	if err != nil {
   132  		return nil, fmt.Errorf("reading db: %w", err)
   133  	}
   134  
   135  	return &rows[T]{wrapped: r, numCols: len(ts.columns)}, nil
   136  }
   137  
   138  // rows is a wrapper around the stdlib's sql.rows, but with a more
   139  // type-safe method to get actual row content.
   140  type rows[T any] struct {
   141  	wrapped *sql.Rows
   142  	numCols int
   143  }
   144  
   145  // Next is a wrapper around sql.Rows.Next(). It must be called before every call
   146  // to Get(), including the first.
   147  func (r rows[T]) Next() bool {
   148  	return r.wrapped.Next()
   149  }
   150  
   151  // Get is a wrapper around sql.Rows.Scan(). Rather than populating an arbitrary
   152  // number of &interface{} arguments, it returns a populated object of the
   153  // parameterized type.
   154  func (r rows[T]) Get() (*T, error) {
   155  	result := new(T)
   156  	v := reflect.ValueOf(result)
   157  
   158  	// Because sql.Rows.Scan(...) takes a variadic number of individual targets to
   159  	// read values into, build a slice that can be splatted into the call. Use the
   160  	// pre-computed list of in-order column names to populate it.
   161  	scanTargets := make([]interface{}, r.numCols)
   162  	for i := range scanTargets {
   163  		field := v.Elem().Field(i)
   164  		scanTargets[i] = field.Addr().Interface()
   165  	}
   166  
   167  	err := r.wrapped.Scan(scanTargets...)
   168  	if err != nil {
   169  		return nil, fmt.Errorf("reading db row: %w", err)
   170  	}
   171  
   172  	return result, nil
   173  }
   174  
   175  // Err is a wrapper around sql.Rows.Err(). It should be checked immediately
   176  // after Next() returns false for any reason.
   177  func (r rows[T]) Err() error {
   178  	return r.wrapped.Err()
   179  }
   180  
   181  // Close is a wrapper around sql.Rows.Close(). It must be called when the caller
   182  // is done reading rows, regardless of success or error.
   183  func (r rows[T]) Close() error {
   184  	return r.wrapped.Close()
   185  }
   186  

View as plain text