1 package sqlmock
2
3 import (
4 "bytes"
5 "database/sql/driver"
6 "encoding/csv"
7 "errors"
8 "fmt"
9 "io"
10 "strings"
11 )
12
// invalidate is the marker text that invalidateRaw writes, repeated, over
// byte slices previously handed out by rowSets.Next, so a caller still holding
// one sees obvious garbage instead of plausible stale row data.
const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "
14
15
16
17
// CSVColumnParser converts a single CSV field (already whitespace-trimmed by
// Rows.FromCSVString) into the value stored for that column. The text "null",
// matched case-insensitively, becomes nil; anything else is returned as a
// []byte copy of the text. It is a package-level variable and may be
// reassigned to customize parsing.
var CSVColumnParser = func(s string) interface{} {
	if strings.EqualFold(s, "null") {
		return nil
	}
	return []byte(s)
}
25
// rowSets is the cursor over the Rows configured for an ExpectedQuery. It
// serves rows from sets[pos] and remembers the byte slices it handed out so
// they can be overwritten on the next Next or Close.
type rowSets struct {
	sets []*Rows        // configured result sets
	pos  int            // index into sets of the result set currently being read
	ex   *ExpectedQuery // owning expectation; Close sets its rowsWereClosed flag
	raw  [][]byte       // byte copies returned by Next, scrubbed by invalidateRaw
}
32
33 func (rs *rowSets) Columns() []string {
34 return rs.sets[rs.pos].cols
35 }
36
37 func (rs *rowSets) Close() error {
38 rs.invalidateRaw()
39 rs.ex.rowsWereClosed = true
40 return rs.sets[rs.pos].closeErr
41 }
42
43
44 func (rs *rowSets) Next(dest []driver.Value) error {
45 r := rs.sets[rs.pos]
46 r.pos++
47 rs.invalidateRaw()
48 if r.pos > len(r.rows) {
49 return io.EOF
50 }
51
52 for i, col := range r.rows[r.pos-1] {
53 if b, ok := rawBytes(col); ok {
54 rs.raw = append(rs.raw, b)
55 dest[i] = b
56 continue
57 }
58 dest[i] = col
59 }
60
61 return r.nextErr[r.pos-1]
62 }
63
64
65 func (rs *rowSets) String() string {
66 if rs.empty() {
67 return "with empty rows"
68 }
69
70 msg := "should return rows:\n"
71 if len(rs.sets) == 1 {
72 for n, row := range rs.sets[0].rows {
73 msg += fmt.Sprintf(" row %d - %+v\n", n, row)
74 }
75 return strings.TrimSpace(msg)
76 }
77 for i, set := range rs.sets {
78 msg += fmt.Sprintf(" result set: %d\n", i)
79 for n, row := range set.rows {
80 msg += fmt.Sprintf(" row %d - %+v\n", n, row)
81 }
82 }
83 return strings.TrimSpace(msg)
84 }
85
86 func (rs *rowSets) empty() bool {
87 for _, set := range rs.sets {
88 if len(set.rows) > 0 {
89 return false
90 }
91 }
92 return true
93 }
94
// rawBytes returns a private copy of col, and true, when col holds a non-empty
// []byte. For any other value, including an empty or nil []byte, it returns
// nil and false. Copying lets the caller later overwrite the returned slice
// without touching the stored row data.
func rawBytes(col driver.Value) (_ []byte, ok bool) {
	src, isBytes := col.([]byte)
	if isBytes && len(src) > 0 {
		dup := make([]byte, len(src))
		copy(dup, src)
		return dup, true
	}
	return nil, false
}
106
107
108
109 func (rs *rowSets) invalidateRaw() {
110
111 b := []byte(invalidate)
112 for _, r := range rs.raw {
113 copy(r, bytes.Repeat(b, len(r)/len(b)+1))
114 }
115
116 rs.raw = nil
117 }
118
119
120
// Rows is a mocked set of result rows. It holds the column names and the row
// values, plus optional per-row errors returned by Next and an optional error
// returned on Close. Create one with NewRows and populate it with AddRow,
// AddRows or FromCSVString.
type Rows struct {
	converter driver.ValueConverter // converts each value passed to AddRow
	cols      []string              // column names, reported by Columns
	def       []*Column             // NOTE(review): unused in this chunk; presumably per-column metadata — confirm
	rows      [][]driver.Value      // stored row values, one slice per row
	pos       int                   // times Next has advanced this set; rows[pos-1] is the last row read
	nextErr   map[int]error         // error Next returns for a zero-based row index (see RowError)
	closeErr  error                 // error returned when the row set is closed (see CloseError)
}
130
131
132
133
134
135 func NewRows(columns []string) *Rows {
136 return &Rows{
137 cols: columns,
138 nextErr: make(map[int]error),
139 converter: driver.DefaultParameterConverter,
140 }
141 }
142
143
144
145
146
147
148
149
// CloseError sets the error returned when the row set built from r is
// closed. It returns r so calls can be chained.
func (r *Rows) CloseError(err error) *Rows {
	r.closeErr = err
	return r
}
154
155
156
157
158 func (r *Rows) RowError(row int, err error) *Rows {
159 r.nextErr[row] = err
160 return r
161 }
162
163
164
165
166
167 func (r *Rows) AddRow(values ...driver.Value) *Rows {
168 if len(values) != len(r.cols) {
169 panic(fmt.Sprintf("Expected number of values to match number of columns: expected %d, actual %d", len(values), len(r.cols)))
170 }
171
172 row := make([]driver.Value, len(r.cols))
173 for i, v := range values {
174
175
176 var err error
177 v, err = r.converter.ConvertValue(v)
178 if err != nil {
179 panic(fmt.Errorf(
180 "row #%d, column #%d (%q) type %T: %s",
181 len(r.rows)+1, i, r.cols[i], values[i], err,
182 ))
183 }
184
185 row[i] = v
186 }
187
188 r.rows = append(r.rows, row)
189 return r
190 }
191
192
193
194 func (r *Rows) AddRows(values ...[]driver.Value) *Rows {
195 for _, value := range values {
196 r.AddRow(value...)
197 }
198
199 return r
200 }
201
202
203
204
205
206 func (r *Rows) FromCSVString(s string) *Rows {
207 res := strings.NewReader(strings.TrimSpace(s))
208 csvReader := csv.NewReader(res)
209
210 for {
211 res, err := csvReader.Read()
212 if err != nil {
213 if errors.Is(err, io.EOF) {
214 break
215 }
216 panic(fmt.Sprintf("Parsing CSV string failed: %s", err.Error()))
217 }
218
219 row := make([]driver.Value, len(r.cols))
220 for i, v := range res {
221 row[i] = CSVColumnParser(strings.TrimSpace(v))
222 }
223 r.rows = append(r.rows, row)
224 }
225 return r
226 }
227
View as plain text