1 package sqlgen
2
3 import (
4 "strings"
5
6 "github.com/doug-martin/goqu/v9/exp"
7 "github.com/doug-martin/goqu/v9/internal/errors"
8 "github.com/doug-martin/goqu/v9/internal/sb"
9 )
10
11 type (
12
13
14 InsertSQLGenerator interface {
15 Dialect() string
16 Generate(b sb.SQLBuilder, clauses exp.InsertClauses)
17 }
18
19
20
21 insertSQLGenerator struct {
22 CommonSQLGenerator
23 }
24 )
25
// ErrConflictUpdateValuesRequired is recorded on the builder when a conflict-update
// (upsert) expression is generated without any update values.
var ErrConflictUpdateValuesRequired = errors.New("values are required for on conflict update expression")

// ErrNoSourceForInsert is recorded on the builder when an insert cannot be
// generated because it has nothing to insert into or from, for example when the
// insert clauses have no INTO target.
var ErrNoSourceForInsert = errors.New("no source found when generating insert sql")
30
31 func errMisMatchedRowLength(expectedL, actualL int) error {
32 return errors.New("rows with different value length expected %d got %d", expectedL, actualL)
33 }
34
35 func errUpsertWithWhereNotSupported(dialect string) error {
36 return errors.New("dialect does not support upsert with where clause [dialect=%s]", dialect)
37 }
38
39 func NewInsertSQLGenerator(dialect string, do *SQLDialectOptions) InsertSQLGenerator {
40 return &insertSQLGenerator{NewCommonSQLGenerator(dialect, do)}
41 }
42
43 func (isg *insertSQLGenerator) Generate(
44 b sb.SQLBuilder,
45 clauses exp.InsertClauses,
46 ) {
47 if !clauses.HasInto() {
48 b.SetError(ErrNoSourceForInsert)
49 return
50 }
51 for _, f := range isg.DialectOptions().InsertSQLOrder {
52 if b.Error() != nil {
53 return
54 }
55 switch f {
56 case CommonTableSQLFragment:
57 isg.ExpressionSQLGenerator().Generate(b, clauses.CommonTables())
58 case InsertBeingSQLFragment:
59 isg.InsertBeginSQL(b, clauses.OnConflict())
60 case IntoSQLFragment:
61 b.WriteRunes(isg.DialectOptions().SpaceRune)
62 isg.ExpressionSQLGenerator().Generate(b, clauses.Into())
63 case InsertSQLFragment:
64 isg.InsertSQL(b, clauses)
65 case ReturningSQLFragment:
66 isg.ReturningSQL(b, clauses.Returning())
67 default:
68 b.SetError(ErrNotSupportedFragment("INSERT", f))
69 }
70 }
71 }
72
73
74 func (isg *insertSQLGenerator) InsertBeginSQL(b sb.SQLBuilder, o exp.ConflictExpression) {
75 if isg.DialectOptions().SupportsInsertIgnoreSyntax && o != nil {
76 b.Write(isg.DialectOptions().InsertIgnoreClause)
77 } else {
78 b.Write(isg.DialectOptions().InsertClause)
79 }
80 }
81
82
83 func (isg *insertSQLGenerator) InsertSQL(b sb.SQLBuilder, ic exp.InsertClauses) {
84 switch {
85 case ic.HasRows():
86 ie, err := exp.NewInsertExpression(ic.Rows()...)
87 if err != nil {
88 b.SetError(err)
89 return
90 }
91 isg.InsertExpressionSQL(b, ie)
92 case ic.HasCols() && ic.HasVals():
93 isg.insertColumnsSQL(b, ic.Cols())
94 isg.insertValuesSQL(b, ic.Vals())
95 case ic.HasCols() && ic.HasFrom():
96 isg.insertColumnsSQL(b, ic.Cols())
97 isg.insertFromSQL(b, ic.From())
98 case ic.HasFrom():
99 isg.insertFromSQL(b, ic.From())
100 default:
101 isg.defaultValuesSQL(b)
102 }
103 if ic.HasAlias() {
104 b.Write(isg.DialectOptions().AsFragment)
105 isg.ExpressionSQLGenerator().Generate(b, ic.Alias())
106 }
107 isg.onConflictSQL(b, ic.OnConflict())
108 }
109
110 func (isg *insertSQLGenerator) InsertExpressionSQL(b sb.SQLBuilder, ie exp.InsertExpression) {
111 switch {
112 case ie.IsInsertFrom():
113 isg.insertFromSQL(b, ie.From())
114 case ie.IsEmpty():
115 isg.defaultValuesSQL(b)
116 default:
117 isg.insertColumnsSQL(b, ie.Cols())
118 isg.insertValuesSQL(b, ie.Vals())
119 }
120 }
121
122
123 func (isg *insertSQLGenerator) defaultValuesSQL(b sb.SQLBuilder) {
124 b.Write(isg.DialectOptions().DefaultValuesFragment)
125 }
126
127 func (isg *insertSQLGenerator) insertFromSQL(b sb.SQLBuilder, ae exp.AppendableExpression) {
128 b.WriteRunes(isg.DialectOptions().SpaceRune)
129 ae.AppendSQL(b)
130 }
131
132
133 func (isg *insertSQLGenerator) insertColumnsSQL(b sb.SQLBuilder, cols exp.ColumnListExpression) {
134 b.WriteRunes(isg.DialectOptions().SpaceRune, isg.DialectOptions().LeftParenRune)
135 isg.ExpressionSQLGenerator().Generate(b, cols)
136 b.WriteRunes(isg.DialectOptions().RightParenRune)
137 }
138
139
140 func (isg *insertSQLGenerator) insertValuesSQL(b sb.SQLBuilder, values [][]interface{}) {
141 b.Write(isg.DialectOptions().ValuesFragment)
142 rowLen := len(values[0])
143 valueLen := len(values)
144 for i, row := range values {
145 if len(row) != rowLen {
146 b.SetError(errMisMatchedRowLength(rowLen, len(row)))
147 return
148 }
149 isg.ExpressionSQLGenerator().Generate(b, row)
150 if i < valueLen-1 {
151 b.WriteRunes(isg.DialectOptions().CommaRune, isg.DialectOptions().SpaceRune)
152 }
153 }
154 }
155
156
157 func (isg *insertSQLGenerator) onConflictSQL(b sb.SQLBuilder, o exp.ConflictExpression) {
158 if o == nil {
159 return
160 }
161 b.Write(isg.DialectOptions().ConflictFragment)
162 switch t := o.(type) {
163 case exp.ConflictUpdateExpression:
164 target := t.TargetColumn()
165 if isg.DialectOptions().SupportsConflictTarget && target != "" {
166 wrapParens := !strings.HasPrefix(strings.ToLower(target), "on constraint")
167
168 b.WriteRunes(isg.DialectOptions().SpaceRune)
169 if wrapParens {
170 b.WriteRunes(isg.DialectOptions().LeftParenRune).
171 WriteStrings(target).
172 WriteRunes(isg.DialectOptions().RightParenRune)
173 } else {
174 b.Write([]byte(target))
175 }
176 }
177 isg.onConflictDoUpdateSQL(b, t)
178 default:
179 b.Write(isg.DialectOptions().ConflictDoNothingFragment)
180 }
181 }
182
183 func (isg *insertSQLGenerator) onConflictDoUpdateSQL(b sb.SQLBuilder, o exp.ConflictUpdateExpression) {
184 b.Write(isg.DialectOptions().ConflictDoUpdateFragment)
185 update := o.Update()
186 if update == nil {
187 b.SetError(ErrConflictUpdateValuesRequired)
188 return
189 }
190 ue, err := exp.NewUpdateExpressions(update)
191 if err != nil {
192 b.SetError(err)
193 return
194 }
195 isg.UpdateExpressionSQL(b, ue...)
196 if b.Error() == nil && o.WhereClause() != nil {
197 if !isg.DialectOptions().SupportsConflictUpdateWhere {
198 b.SetError(errUpsertWithWhereNotSupported(isg.Dialect()))
199 return
200 }
201 isg.WhereSQL(b, o.WhereClause())
202 }
203 }
204
View as plain text