1 package exp
2
3 import (
4 "reflect"
5 "sort"
6
7 "github.com/doug-martin/goqu/v9/internal/errors"
8 "github.com/doug-martin/goqu/v9/internal/util"
9 )
10
11 type (
12 insert struct {
13 from AppendableExpression
14 cols ColumnListExpression
15 vals [][]interface{}
16 }
17 )
18
19 func NewInsertExpression(rows ...interface{}) (insertExpression InsertExpression, err error) {
20 switch len(rows) {
21 case 0:
22 return new(insert), nil
23 case 1:
24 val := reflect.ValueOf(rows[0])
25 if val.Kind() == reflect.Slice {
26 vals := make([]interface{}, 0, val.Len())
27 for i := 0; i < val.Len(); i++ {
28 vals = append(vals, val.Index(i).Interface())
29 }
30 return NewInsertExpression(vals...)
31 }
32 if ae, ok := rows[0].(AppendableExpression); ok {
33 return &insert{from: ae}, nil
34 }
35 }
36 return newInsert(rows...)
37 }
38
39 func (i *insert) Expression() Expression {
40 return i
41 }
42
43 func (i *insert) Clone() Expression {
44 return i.clone()
45 }
46
47 func (i *insert) clone() *insert {
48 return &insert{from: i.from, cols: i.cols, vals: i.vals}
49 }
50
51 func (i *insert) IsEmpty() bool {
52 return i.from == nil && (i.cols == nil || i.cols.IsEmpty())
53 }
54
55 func (i *insert) IsInsertFrom() bool {
56 return i.from != nil
57 }
58
59 func (i *insert) From() AppendableExpression {
60 return i.from
61 }
62
63 func (i *insert) Cols() ColumnListExpression {
64 return i.cols
65 }
66
67 func (i *insert) SetCols(cols ColumnListExpression) InsertExpression {
68 ci := i.clone()
69 ci.cols = cols
70 return ci
71 }
72
73 func (i *insert) Vals() [][]interface{} {
74 return i.vals
75 }
76
77 func (i *insert) SetVals(vals [][]interface{}) InsertExpression {
78 ci := i.clone()
79 ci.vals = vals
80 return ci
81 }
82
83
84 func newInsert(rows ...interface{}) (insertExp InsertExpression, err error) {
85 var mapKeys util.ValueSlice
86 rowValue := reflect.Indirect(reflect.ValueOf(rows[0]))
87 rowType := rowValue.Type()
88 rowKind := rowValue.Kind()
89 if rowKind == reflect.Struct {
90 return createStructSliceInsert(rows...)
91 }
92 vals := make([][]interface{}, 0, len(rows))
93 var columns ColumnListExpression
94 for _, row := range rows {
95 if rowType != reflect.Indirect(reflect.ValueOf(row)).Type() {
96 return nil, errors.New(
97 "rows must be all the same type expected %+v got %+v",
98 rowType,
99 reflect.Indirect(reflect.ValueOf(row)).Type(),
100 )
101 }
102 newRowValue := reflect.Indirect(reflect.ValueOf(row))
103 switch rowKind {
104 case reflect.Map:
105 if columns == nil {
106 mapKeys = util.ValueSlice(newRowValue.MapKeys())
107 sort.Sort(mapKeys)
108 colKeys := make([]interface{}, 0, len(mapKeys))
109 for _, key := range mapKeys {
110 colKeys = append(colKeys, key.Interface())
111 }
112 columns = NewColumnListExpression(colKeys...)
113 }
114 newMapKeys := util.ValueSlice(newRowValue.MapKeys())
115 if len(newMapKeys) != len(mapKeys) {
116 return nil, errors.New("rows with different value length expected %d got %d", len(mapKeys), len(newMapKeys))
117 }
118 if !mapKeys.Equal(newMapKeys) {
119 return nil, errors.New("rows with different keys expected %s got %s", mapKeys.String(), newMapKeys.String())
120 }
121 rowVals := make([]interface{}, 0, len(mapKeys))
122 for _, key := range mapKeys {
123 rowVals = append(rowVals, newRowValue.MapIndex(key).Interface())
124 }
125 vals = append(vals, rowVals)
126 default:
127 return nil, errors.New(
128 "unsupported insert must be map, goqu.Record, or struct type got: %T",
129 row,
130 )
131 }
132 }
133 return &insert{cols: columns, vals: vals}, nil
134 }
135
136 func createStructSliceInsert(rows ...interface{}) (insertExp InsertExpression, err error) {
137 rowValue := reflect.Indirect(reflect.ValueOf(rows[0]))
138 rowType := rowValue.Type()
139 recordRows := make([]interface{}, 0, len(rows))
140 for _, row := range rows {
141 if rowType != reflect.Indirect(reflect.ValueOf(row)).Type() {
142 return nil, errors.New(
143 "rows must be all the same type expected %+v got %+v",
144 rowType,
145 reflect.Indirect(reflect.ValueOf(row)).Type(),
146 )
147 }
148 newRowValue := reflect.Indirect(reflect.ValueOf(row))
149 record, err := getFieldsValuesFromStruct(newRowValue)
150 if err != nil {
151 return nil, err
152 }
153 recordRows = append(recordRows, record)
154 }
155 return newInsert(recordRows...)
156 }
157
158 func getFieldsValuesFromStruct(value reflect.Value) (row Record, err error) {
159 if value.IsValid() {
160 return NewRecordFromStruct(value.Interface(), true, false)
161 }
162 return
163 }
164
View as plain text