...

Source file src/github.com/doug-martin/goqu/v9/sqlgen/insert_sql_generator.go

Documentation: github.com/doug-martin/goqu/v9/sqlgen

     1  package sqlgen
     2  
     3  import (
     4  	"strings"
     5  
     6  	"github.com/doug-martin/goqu/v9/exp"
     7  	"github.com/doug-martin/goqu/v9/internal/errors"
     8  	"github.com/doug-martin/goqu/v9/internal/sb"
     9  )
    10  
    11  type (
    12  	// An adapter interface to be used by a Dataset to generate SQL for a specific dialect.
    13  	// See DefaultAdapter for a concrete implementation and examples.
    14  	InsertSQLGenerator interface {
    15  		Dialect() string
    16  		Generate(b sb.SQLBuilder, clauses exp.InsertClauses)
    17  	}
    18  	// The default adapter. This class should be used when building a new adapter. When creating a new adapter you can
    19  	// either override methods, or more typically update default values.
    20  	// See (github.com/doug-martin/goqu/dialect/postgres)
    21  	insertSQLGenerator struct {
    22  		CommonSQLGenerator
    23  	}
    24  )
    25  
    26  var (
    27  	ErrConflictUpdateValuesRequired = errors.New("values are required for on conflict update expression")
    28  	ErrNoSourceForInsert            = errors.New("no source found when generating insert sql")
    29  )
    30  
    31  func errMisMatchedRowLength(expectedL, actualL int) error {
    32  	return errors.New("rows with different value length expected %d got %d", expectedL, actualL)
    33  }
    34  
    35  func errUpsertWithWhereNotSupported(dialect string) error {
    36  	return errors.New("dialect does not support upsert with where clause [dialect=%s]", dialect)
    37  }
    38  
    39  func NewInsertSQLGenerator(dialect string, do *SQLDialectOptions) InsertSQLGenerator {
    40  	return &insertSQLGenerator{NewCommonSQLGenerator(dialect, do)}
    41  }
    42  
    43  func (isg *insertSQLGenerator) Generate(
    44  	b sb.SQLBuilder,
    45  	clauses exp.InsertClauses,
    46  ) {
    47  	if !clauses.HasInto() {
    48  		b.SetError(ErrNoSourceForInsert)
    49  		return
    50  	}
    51  	for _, f := range isg.DialectOptions().InsertSQLOrder {
    52  		if b.Error() != nil {
    53  			return
    54  		}
    55  		switch f {
    56  		case CommonTableSQLFragment:
    57  			isg.ExpressionSQLGenerator().Generate(b, clauses.CommonTables())
    58  		case InsertBeingSQLFragment:
    59  			isg.InsertBeginSQL(b, clauses.OnConflict())
    60  		case IntoSQLFragment:
    61  			b.WriteRunes(isg.DialectOptions().SpaceRune)
    62  			isg.ExpressionSQLGenerator().Generate(b, clauses.Into())
    63  		case InsertSQLFragment:
    64  			isg.InsertSQL(b, clauses)
    65  		case ReturningSQLFragment:
    66  			isg.ReturningSQL(b, clauses.Returning())
    67  		default:
    68  			b.SetError(ErrNotSupportedFragment("INSERT", f))
    69  		}
    70  	}
    71  }
    72  
    73  // Adds the correct fragment to being an INSERT statement
    74  func (isg *insertSQLGenerator) InsertBeginSQL(b sb.SQLBuilder, o exp.ConflictExpression) {
    75  	if isg.DialectOptions().SupportsInsertIgnoreSyntax && o != nil {
    76  		b.Write(isg.DialectOptions().InsertIgnoreClause)
    77  	} else {
    78  		b.Write(isg.DialectOptions().InsertClause)
    79  	}
    80  }
    81  
    82  // Adds the columns list to an insert statement
    83  func (isg *insertSQLGenerator) InsertSQL(b sb.SQLBuilder, ic exp.InsertClauses) {
    84  	switch {
    85  	case ic.HasRows():
    86  		ie, err := exp.NewInsertExpression(ic.Rows()...)
    87  		if err != nil {
    88  			b.SetError(err)
    89  			return
    90  		}
    91  		isg.InsertExpressionSQL(b, ie)
    92  	case ic.HasCols() && ic.HasVals():
    93  		isg.insertColumnsSQL(b, ic.Cols())
    94  		isg.insertValuesSQL(b, ic.Vals())
    95  	case ic.HasCols() && ic.HasFrom():
    96  		isg.insertColumnsSQL(b, ic.Cols())
    97  		isg.insertFromSQL(b, ic.From())
    98  	case ic.HasFrom():
    99  		isg.insertFromSQL(b, ic.From())
   100  	default:
   101  		isg.defaultValuesSQL(b)
   102  	}
   103  	if ic.HasAlias() {
   104  		b.Write(isg.DialectOptions().AsFragment)
   105  		isg.ExpressionSQLGenerator().Generate(b, ic.Alias())
   106  	}
   107  	isg.onConflictSQL(b, ic.OnConflict())
   108  }
   109  
   110  func (isg *insertSQLGenerator) InsertExpressionSQL(b sb.SQLBuilder, ie exp.InsertExpression) {
   111  	switch {
   112  	case ie.IsInsertFrom():
   113  		isg.insertFromSQL(b, ie.From())
   114  	case ie.IsEmpty():
   115  		isg.defaultValuesSQL(b)
   116  	default:
   117  		isg.insertColumnsSQL(b, ie.Cols())
   118  		isg.insertValuesSQL(b, ie.Vals())
   119  	}
   120  }
   121  
   122  // Adds the DefaultValuesFragment to an SQL statement
   123  func (isg *insertSQLGenerator) defaultValuesSQL(b sb.SQLBuilder) {
   124  	b.Write(isg.DialectOptions().DefaultValuesFragment)
   125  }
   126  
   127  func (isg *insertSQLGenerator) insertFromSQL(b sb.SQLBuilder, ae exp.AppendableExpression) {
   128  	b.WriteRunes(isg.DialectOptions().SpaceRune)
   129  	ae.AppendSQL(b)
   130  }
   131  
   132  // Adds the columns list to an insert statement
   133  func (isg *insertSQLGenerator) insertColumnsSQL(b sb.SQLBuilder, cols exp.ColumnListExpression) {
   134  	b.WriteRunes(isg.DialectOptions().SpaceRune, isg.DialectOptions().LeftParenRune)
   135  	isg.ExpressionSQLGenerator().Generate(b, cols)
   136  	b.WriteRunes(isg.DialectOptions().RightParenRune)
   137  }
   138  
   139  // Adds the values clause to an SQL statement
   140  func (isg *insertSQLGenerator) insertValuesSQL(b sb.SQLBuilder, values [][]interface{}) {
   141  	b.Write(isg.DialectOptions().ValuesFragment)
   142  	rowLen := len(values[0])
   143  	valueLen := len(values)
   144  	for i, row := range values {
   145  		if len(row) != rowLen {
   146  			b.SetError(errMisMatchedRowLength(rowLen, len(row)))
   147  			return
   148  		}
   149  		isg.ExpressionSQLGenerator().Generate(b, row)
   150  		if i < valueLen-1 {
   151  			b.WriteRunes(isg.DialectOptions().CommaRune, isg.DialectOptions().SpaceRune)
   152  		}
   153  	}
   154  }
   155  
   156  // Adds the DefaultValuesFragment to an SQL statement
   157  func (isg *insertSQLGenerator) onConflictSQL(b sb.SQLBuilder, o exp.ConflictExpression) {
   158  	if o == nil {
   159  		return
   160  	}
   161  	b.Write(isg.DialectOptions().ConflictFragment)
   162  	switch t := o.(type) {
   163  	case exp.ConflictUpdateExpression:
   164  		target := t.TargetColumn()
   165  		if isg.DialectOptions().SupportsConflictTarget && target != "" {
   166  			wrapParens := !strings.HasPrefix(strings.ToLower(target), "on constraint")
   167  
   168  			b.WriteRunes(isg.DialectOptions().SpaceRune)
   169  			if wrapParens {
   170  				b.WriteRunes(isg.DialectOptions().LeftParenRune).
   171  					WriteStrings(target).
   172  					WriteRunes(isg.DialectOptions().RightParenRune)
   173  			} else {
   174  				b.Write([]byte(target))
   175  			}
   176  		}
   177  		isg.onConflictDoUpdateSQL(b, t)
   178  	default:
   179  		b.Write(isg.DialectOptions().ConflictDoNothingFragment)
   180  	}
   181  }
   182  
   183  func (isg *insertSQLGenerator) onConflictDoUpdateSQL(b sb.SQLBuilder, o exp.ConflictUpdateExpression) {
   184  	b.Write(isg.DialectOptions().ConflictDoUpdateFragment)
   185  	update := o.Update()
   186  	if update == nil {
   187  		b.SetError(ErrConflictUpdateValuesRequired)
   188  		return
   189  	}
   190  	ue, err := exp.NewUpdateExpressions(update)
   191  	if err != nil {
   192  		b.SetError(err)
   193  		return
   194  	}
   195  	isg.UpdateExpressionSQL(b, ue...)
   196  	if b.Error() == nil && o.WhereClause() != nil {
   197  		if !isg.DialectOptions().SupportsConflictUpdateWhere {
   198  			b.SetError(errUpsertWithWhereNotSupported(isg.Dialect()))
   199  			return
   200  		}
   201  		isg.WhereSQL(b, o.WhereClause())
   202  	}
   203  }
   204  

View as plain text