...
1 package sqlgen
2
3 import (
4 "github.com/doug-martin/goqu/v9/exp"
5 "github.com/doug-martin/goqu/v9/internal/errors"
6 "github.com/doug-martin/goqu/v9/internal/sb"
7 )
8
9 type (
10
11
12 UpdateSQLGenerator interface {
13 Dialect() string
14 Generate(b sb.SQLBuilder, clauses exp.UpdateClauses)
15 }
16
17
18
19 updateSQLGenerator struct {
20 CommonSQLGenerator
21 }
22 )
23
24 var (
25 ErrNoSourceForUpdate = errors.New("no source found when generating update sql")
26 ErrNoSetValuesForUpdate = errors.New("no set values found when generating UPDATE sql")
27 )
28
29 func NewUpdateSQLGenerator(dialect string, do *SQLDialectOptions) UpdateSQLGenerator {
30 return &updateSQLGenerator{NewCommonSQLGenerator(dialect, do)}
31 }
32
33 func (usg *updateSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.UpdateClauses) {
34 if !clauses.HasTable() {
35 b.SetError(ErrNoSourceForUpdate)
36 return
37 }
38 if !clauses.HasSetValues() {
39 b.SetError(ErrNoSetValuesForUpdate)
40 return
41 }
42 if !usg.DialectOptions().SupportsMultipleUpdateTables && clauses.HasFrom() {
43 b.SetError(errors.New("%s dialect does not support multiple tables in UPDATE", usg.Dialect()))
44 }
45 updates, err := exp.NewUpdateExpressions(clauses.SetValues())
46 if err != nil {
47 b.SetError(err)
48 return
49 }
50 for _, f := range usg.DialectOptions().UpdateSQLOrder {
51 if b.Error() != nil {
52 return
53 }
54 switch f {
55 case CommonTableSQLFragment:
56 usg.ExpressionSQLGenerator().Generate(b, clauses.CommonTables())
57 case UpdateBeginSQLFragment:
58 usg.UpdateBeginSQL(b)
59 case SourcesSQLFragment:
60 usg.updateTableSQL(b, clauses)
61 case UpdateSQLFragment:
62 usg.UpdateExpressionsSQL(b, updates...)
63 case UpdateFromSQLFragment:
64 usg.updateFromSQL(b, clauses.From())
65 case WhereSQLFragment:
66 usg.WhereSQL(b, clauses.Where())
67 case OrderSQLFragment:
68 if usg.DialectOptions().SupportsOrderByOnUpdate {
69 usg.OrderSQL(b, clauses.Order())
70 }
71 case LimitSQLFragment:
72 if usg.DialectOptions().SupportsLimitOnUpdate {
73 usg.LimitSQL(b, clauses.Limit())
74 }
75 case ReturningSQLFragment:
76 usg.ReturningSQL(b, clauses.Returning())
77 default:
78 b.SetError(ErrNotSupportedFragment("UPDATE", f))
79 }
80 }
81 }
82
83
84 func (usg *updateSQLGenerator) UpdateBeginSQL(b sb.SQLBuilder) {
85 b.Write(usg.DialectOptions().UpdateClause)
86 }
87
88
89 func (usg *updateSQLGenerator) UpdateExpressionsSQL(b sb.SQLBuilder, updates ...exp.UpdateExpression) {
90 b.Write(usg.DialectOptions().SetFragment)
91 usg.UpdateExpressionSQL(b, updates...)
92 }
93
94 func (usg *updateSQLGenerator) updateTableSQL(b sb.SQLBuilder, uc exp.UpdateClauses) {
95 b.WriteRunes(usg.DialectOptions().SpaceRune)
96 usg.ExpressionSQLGenerator().Generate(b, uc.Table())
97 if uc.HasFrom() {
98 if !usg.DialectOptions().UseFromClauseForMultipleUpdateTables {
99 b.WriteRunes(usg.DialectOptions().CommaRune)
100 usg.ExpressionSQLGenerator().Generate(b, uc.From())
101 }
102 }
103 }
104
105 func (usg *updateSQLGenerator) updateFromSQL(b sb.SQLBuilder, ce exp.ColumnListExpression) {
106 if ce == nil || ce.IsEmpty() {
107 return
108 }
109 if usg.DialectOptions().UseFromClauseForMultipleUpdateTables {
110 usg.FromSQL(b, ce)
111 }
112 }
113
View as plain text