1 package util
2
3 import (
4 "database/sql"
5 "reflect"
6 "strings"
7 "sync"
8
9 "github.com/doug-martin/goqu/v9/internal/errors"
10 )
11
// Struct-tag option names recognised by this package.
const (
	// skipUpdateTagName is the tag option string "skipupdate".
	skipUpdateTagName = "skipupdate"
	// skipInsertTagName is the tag option string "skipinsert".
	skipInsertTagName = "skipinsert"
	// defaultIfEmptyTagName is the tag option string "defaultifempty".
	defaultIfEmptyTagName = "defaultifempty"
)

// scannerType is the reflect.Type of the sql.Scanner interface itself
// (not of a pointer to it), usable with reflect.Type.Implements.
var scannerType = reflect.TypeOf(new(sql.Scanner)).Elem()
19
// IsUint reports whether k is one of the sized or unsized unsigned integer
// kinds: Uint, Uint8, Uint16, Uint32 or Uint64. reflect.Uintptr is not
// included.
func IsUint(k reflect.Kind) bool {
	switch k {
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return true
	default:
		return false
	}
}
27
// IsInt reports whether k is one of the signed integer kinds:
// Int, Int8, Int16, Int32 or Int64.
func IsInt(k reflect.Kind) bool {
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return true
	}
	return false
}
35
// IsFloat reports whether k is Float32 or Float64.
func IsFloat(k reflect.Kind) bool {
	switch k {
	case reflect.Float32, reflect.Float64:
		return true
	default:
		return false
	}
}
40
// IsString reports whether k is reflect.String.
func IsString(k reflect.Kind) bool {
	if k != reflect.String {
		return false
	}
	return true
}
44
// IsBool reports whether k is reflect.Bool.
func IsBool(k reflect.Kind) bool {
	switch k {
	case reflect.Bool:
		return true
	default:
		return false
	}
}
48
// IsSlice reports whether k is reflect.Slice (arrays are not slices).
func IsSlice(k reflect.Kind) bool {
	isSlice := reflect.Slice == k
	return isSlice
}
52
// IsStruct reports whether k is reflect.Struct.
func IsStruct(k reflect.Kind) bool {
	if k == reflect.Struct {
		return true
	}
	return false
}
56
// IsInvalid reports whether k is reflect.Invalid, the kind of the zero
// reflect.Value.
func IsInvalid(k reflect.Kind) bool {
	switch k {
	case reflect.Invalid:
		return true
	}
	return false
}
60
// IsPointer reports whether k is reflect.Ptr.
func IsPointer(k reflect.Kind) bool {
	if k != reflect.Ptr {
		return false
	}
	return true
}
64
// IsEmptyValue reports whether v holds the "empty" value for its kind:
//   - the zero reflect.Value (kind Invalid) is empty;
//   - nil interfaces and nil pointers are empty;
//   - numeric values equal to 0 are empty (for floats, compared with == 0,
//     so -0.0 counts as empty); uintptr is included with the unsigned kinds;
//   - false is empty;
//   - arrays, maps, slices and strings of length 0 are empty.
//
// Every other kind (structs, complex numbers, channels, funcs, ...) is never
// considered empty.
func IsEmptyValue(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Invalid:
		return true
	case reflect.Ptr, reflect.Interface:
		return v.IsNil()
	case reflect.Float32, reflect.Float64:
		return v.Float() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return v.Uint() == 0
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() == 0
	case reflect.Bool:
		return !v.Bool()
	case reflect.String, reflect.Slice, reflect.Map, reflect.Array:
		return v.Len() == 0
	}
	return false
}
85
86 var (
87 structMapCache = make(map[interface{}]ColumnMap)
88 structMapCacheLock = sync.Mutex{}
89 )
90
// DefaultColumnRenameFunction is the initial column rename function:
// strings.ToLower.
var DefaultColumnRenameFunction = strings.ToLower

var (
	// columnRenameFunction is the active rename function; replaced via
	// SetColumnRenameFunction.
	columnRenameFunction = DefaultColumnRenameFunction
	// ignoreUntaggedFields is toggled via SetIgnoreUntaggedFields; it starts
	// out false.
	ignoreUntaggedFields bool
)
96
97 func SetIgnoreUntaggedFields(ignore bool) {
98
99 if ignore != ignoreUntaggedFields {
100 ignoreUntaggedFields = ignore
101
102 structMapCacheLock.Lock()
103 defer structMapCacheLock.Unlock()
104
105 structMapCache = make(map[interface{}]ColumnMap)
106 }
107 }
108
109 func SetColumnRenameFunction(newFunction func(string) string) {
110 columnRenameFunction = newFunction
111 }
112
113
// GetSliceElementType returns the element type of val's type, with at most
// one level of pointer stripped: []T and []*T both yield T, and []**T
// yields *T. It panics (via reflect.Type.Elem) if val's type has no element
// type.
func GetSliceElementType(val reflect.Value) reflect.Type {
	elem := val.Type().Elem()
	if elem.Kind() != reflect.Ptr {
		return elem
	}
	return elem.Elem()
}
122
123
124
// AppendSliceElement appends val to slice, which must be settable.
//
// When the slice's element type is a pointer, val is appended as-is.
// Otherwise val is first passed through reflect.Indirect, so a pointer
// is dereferenced one level and a non-pointer value is appended unchanged.
func AppendSliceElement(slice, val reflect.Value) {
	item := val
	if slice.Type().Elem().Kind() != reflect.Ptr {
		item = reflect.Indirect(val)
	}
	slice.Set(reflect.Append(slice, item))
}
132
// GetTypeInfo resolves the type that i describes, together with that type's
// kind. val is expected to be reflect.Indirect(reflect.ValueOf(i)), which is
// how GetColumnMap calls it.
//
//   - If val is a slice, the slice type is taken from i (one pointer level
//     removed if i is a pointer), and its element type is returned with one
//     more pointer level removed. So []T, []*T, *[]T and *[]*T all resolve
//     to T.
//   - If val is the zero reflect.Value (i was nil or a nil pointer),
//     (nil, reflect.Invalid) is returned instead of panicking in val.Type().
//   - Otherwise val's own type and kind are returned.
func GetTypeInfo(i interface{}, val reflect.Value) (reflect.Type, reflect.Kind) {
	valKind := val.Kind()
	switch valKind {
	case reflect.Invalid:
		return nil, reflect.Invalid
	case reflect.Slice:
		var elemType reflect.Type
		if reflect.ValueOf(i).Kind() == reflect.Ptr {
			elemType = reflect.TypeOf(i).Elem().Elem()
		} else {
			elemType = reflect.TypeOf(i).Elem()
		}
		if elemType.Kind() == reflect.Ptr {
			elemType = elemType.Elem()
		}
		return elemType, elemType.Kind()
	default:
		return val.Type(), valKind
	}
}
151
// SafeGetFieldByIndex walks the struct field path fieldIndex starting at v.
// Every intermediate field is dereferenced with reflect.Indirect.
//
// It returns (field, true) when the path can be followed. If an intermediate
// field is a nil pointer, it returns (reflect.Value{}, false) instead of
// panicking. An empty path returns v itself. v (and each intermediate
// value) must be a struct; a non-struct panics inside reflect.
func SafeGetFieldByIndex(v reflect.Value, fieldIndex []int) (result reflect.Value, isAvailable bool) {
	current, path := v, fieldIndex
	for len(path) > 1 {
		next := reflect.Indirect(current.Field(path[0]))
		if !next.IsValid() {
			// A nil pointer somewhere along the path: the leaf is unreachable.
			return reflect.Value{}, false
		}
		current, path = next, path[1:]
	}
	if len(path) == 0 {
		return current, true
	}
	return current.FieldByIndex(path), true
}
165
// SafeSetFieldByIndex assigns src to the field reached by fieldIndex inside
// the struct v (or the struct v points to). It returns reflect.Indirect(v).
//
//   - src is passed through reflect.Indirect before assignment, so both a
//     value and a pointer to a value of the field's type are accepted.
//   - Nil pointer-to-struct fields along the path are allocated on demand.
//   - Intermediate fields that are neither pointers nor structs stop the walk
//     silently.
//   - The target must be settable (typically v is a pointer), otherwise
//     reflect panics.
func SafeSetFieldByIndex(v reflect.Value, fieldIndex []int, src interface{}) (result reflect.Value) {
	target := reflect.Indirect(v)
	if len(fieldIndex) == 0 {
		return target
	}
	if len(fieldIndex) == 1 {
		target.FieldByIndex(fieldIndex).Set(reflect.Indirect(reflect.ValueOf(src)))
		return target
	}

	head, rest := target.Field(fieldIndex[0]), fieldIndex[1:]
	if head.Kind() == reflect.Ptr {
		if head.IsNil() {
			// Allocate the intermediate struct so the leaf can be written.
			head.Set(reflect.New(head.Type().Elem()))
		}
		SafeSetFieldByIndex(head.Elem(), rest, src)
	} else if head.Kind() == reflect.Struct {
		SafeSetFieldByIndex(head, rest, src)
	}
	return target
}
192
193 type rowData = map[string]interface{}
194
195
196 func AssignStructVals(i interface{}, rd rowData, cm ColumnMap) {
197 val := reflect.Indirect(reflect.ValueOf(i))
198
199 for name, data := range cm {
200 src, ok := rd[name]
201 if ok {
202 SafeSetFieldByIndex(val, data.FieldIndex, src)
203 }
204 }
205 }
206
207 func GetColumnMap(i interface{}) (ColumnMap, error) {
208 val := reflect.Indirect(reflect.ValueOf(i))
209 t, valKind := GetTypeInfo(i, val)
210 if valKind != reflect.Struct {
211 return nil, errors.New("cannot scan into this type: %v", t)
212 }
213
214 structMapCacheLock.Lock()
215 defer structMapCacheLock.Unlock()
216 if _, ok := structMapCache[t]; !ok {
217 structMapCache[t] = newColumnMap(t, []int{}, []string{})
218 }
219 return structMapCache[t], nil
220 }
221
View as plain text