1
2
3
4
5
6
7
8
9
10
11
12
13 package main
14
15 import (
16 "bytes"
17 "fmt"
18 "os"
19 "reflect"
20 "strings"
21 "text/template"
22 )
23
// tmpl is the template set populated by initTemplates and executed by the
// render* helpers.
var tmpl *template.Template

// initTemplates parses every file matched by the glob root+"/*" into tmpl.
// On failure tmpl is left as whatever ParseGlob returned and the parse error
// is raised as a panic.
func initTemplates(root string) {
	parsed, parseErr := template.ParseGlob(root + "/*")
	tmpl = parsed
	if parseErr != nil {
		panic(parseErr)
	}
}
33
34 func renderExpectationsGo(filename string, methods []*method) error {
35 file, err := os.Create(filename)
36 if err != nil {
37 return err
38 }
39 return tmpl.ExecuteTemplate(file, "expectations.go.tmpl", methods)
40 }
41
42 func renderClientGo(filename string, methods []*method) error {
43 file, err := os.Create(filename)
44 if err != nil {
45 return err
46 }
47 return tmpl.ExecuteTemplate(file, "client.go.tmpl", methods)
48 }
49
50 func renderMockGo(filename string, methods []*method) error {
51 file, err := os.Create(filename)
52 if err != nil {
53 return err
54 }
55 return tmpl.ExecuteTemplate(file, "mock.go.tmpl", methods)
56 }
57
58 func renderDriverMethod(m *method) (string, error) {
59 buf := &bytes.Buffer{}
60 err := tmpl.ExecuteTemplate(buf, "drivermethod.tmpl", m)
61 return buf.String(), err
62 }
63
64 func renderExpectedType(m *method) (string, error) {
65 buf := &bytes.Buffer{}
66 err := tmpl.ExecuteTemplate(buf, "expectedtype.tmpl", m)
67 return buf.String(), err
68 }
69
70 func (m *method) DriverArgs() string {
71 const extraCount = 2
72 args := make([]string, 0, len(m.Accepts)+extraCount)
73 if m.AcceptsContext {
74 args = append(args, "ctx context.Context")
75 }
76 for i, arg := range m.Accepts {
77 args = append(args, fmt.Sprintf("arg%d %s", i, typeName(arg)))
78 }
79 if m.AcceptsOptions {
80 args = append(args, "options driver.Options")
81 }
82 return strings.Join(args, ", ")
83 }
84
85 func (m *method) ReturnArgs() string {
86 args := make([]string, 0, len(m.Returns)+1)
87 for _, arg := range m.Returns {
88 args = append(args, arg.String())
89 }
90 if m.ReturnsError {
91 args = append(args, "error")
92 }
93 if len(args) > 1 {
94 return `(` + strings.Join(args, ", ") + `)`
95 }
96 return args[0]
97 }
98
99 func (m *method) VariableDefinitions() string {
100 result := make([]string, 0, len(m.Accepts)+len(m.Returns))
101 for i, arg := range m.Accepts {
102 result = append(result, fmt.Sprintf("\targ%d %s\n", i, typeName(arg)))
103 }
104 for i, ret := range m.Returns {
105 name := typeName(ret)
106 switch name {
107 case "driver.DB":
108 name = "*DB"
109 case "driver.Replication":
110 name = "*Replication"
111 case "[]driver.Replication":
112 name = "[]*Replication"
113 }
114 result = append(result, fmt.Sprintf("\tret%d %s\n", i, name))
115 }
116 return strings.Join(result, "")
117 }
118
119 func (m *method) inputVars() []string {
120 args := make([]string, 0, len(m.Accepts)+1)
121 for i := range m.Accepts {
122 args = append(args, fmt.Sprintf("arg%d", i))
123 }
124 if m.AcceptsOptions {
125 args = append(args, "options")
126 }
127 return args
128 }
129
130 func (m *method) ExpectedVariables() string {
131 args := []string{}
132 if m.DBMethod {
133 args = append(args, "db")
134 }
135 args = append(args, m.inputVars()...)
136 return alignVars(0, args)
137 }
138
139 func (m *method) InputVariables() string {
140 result := make([]string, len(m.Accepts)+1)
141 var common []string
142 if m.DBMethod {
143 common = append(common, "\t\t\tdb: db.DB,\n")
144 }
145 for i := range m.Accepts {
146 result = append(result, fmt.Sprintf("\t\targ%d: arg%d,\n", i, i))
147 }
148 if m.AcceptsOptions {
149 common = append(common, "\t\t\toptions: options,\n")
150 }
151 if len(common) > 0 {
152 result = append(result, fmt.Sprintf("\t\tcommonExpectation: commonExpectation{\n%s\t\t},\n",
153 strings.Join(common, "")))
154 }
155 return strings.Join(result, "")
156 }
157
158 func (m *method) Variables(indent int) string {
159 args := m.inputVars()
160 for i := range m.Returns {
161 args = append(args, fmt.Sprintf("ret%d", i))
162 }
163 return alignVars(indent, args)
164 }
165
// alignVars formats each name in args as "name: name,", one per line. Every
// line is prefixed by indent tab characters. The "name:" keys are
// left-justified to a common width (the longest name plus one for the colon),
// so the values line up in a column. Lines are joined by "\n" with no
// trailing newline; an empty args yields "".
func alignVars(indent int, args []string) string {
	if len(args) == 0 {
		return ""
	}
	width := 0
	for _, name := range args {
		if len(name) > width {
			width = len(name)
		}
	}
	width++ // room for the ':' appended to each key
	tabs := strings.Repeat("\t", indent)
	lines := make([]string, 0, len(args))
	for _, name := range args {
		lines = append(lines, fmt.Sprintf("%s%-*s %s,", tabs, width, name+":", name))
	}
	return strings.Join(lines, "\n")
}
179
180 func (m *method) ZeroReturns() string {
181 args := make([]string, 0, len(m.Returns))
182 for _, arg := range m.Returns {
183 args = append(args, zeroValue(arg))
184 }
185 args = append(args, "err")
186 return strings.Join(args, ", ")
187 }
188
// zeroValue returns Go source for the zero value of t, as rendered by the
// %#v verb. Nil-able zero values are collapsed to plain "nil"; these are the
// ones printed as "<nil>" (interfaces) or "...(nil)" (pointers, slices, maps,
// etc.).
func zeroValue(t reflect.Type) string {
	literal := fmt.Sprintf("%#v", reflect.Zero(t).Interface())
	switch {
	case literal == "<nil>", strings.HasSuffix(literal, "(nil)"):
		return "nil"
	default:
		return literal
	}
}
199
200 func (m *method) ExpectedReturns() string {
201 args := make([]string, 0, len(m.Returns))
202 for i, arg := range m.Returns {
203 switch arg.String() {
204 case "driver.Rows":
205 args = append(args, fmt.Sprintf("&driverRows{Context: ctx, Rows: coalesceRows(expected.ret%d)}", i))
206 case "driver.Changes":
207 args = append(args, fmt.Sprintf("&driverChanges{Context: ctx, Changes: coalesceChanges(expected.ret%d)}", i))
208 case "driver.DB":
209 args = append(args, fmt.Sprintf("&driverDB{DB: expected.ret%d}", i))
210 case "driver.DBUpdates":
211 args = append(args, fmt.Sprintf("&driverDBUpdates{Context:ctx, Updates: coalesceDBUpdates(expected.ret%d)}", i))
212 case "driver.Replication":
213 args = append(args, fmt.Sprintf("&driverReplication{Replication: expected.ret%d}", i))
214 case "[]driver.Replication":
215 args = append(args, fmt.Sprintf("driverReplications(expected.ret%d)", i))
216 default:
217 args = append(args, fmt.Sprintf("expected.ret%d", i))
218 }
219 }
220 if m.AcceptsContext {
221 args = append(args, "expected.wait(ctx)")
222 } else {
223 args = append(args, "expected.err")
224 }
225 return strings.Join(args, ", ")
226 }
227
228 func (m *method) ReturnTypes() string {
229 args := make([]string, len(m.Returns))
230 for i, ret := range m.Returns {
231 name := typeName(ret)
232 switch name {
233 case "driver.DB":
234 name = "*DB"
235 case "driver.Replication":
236 name = "*Replication"
237 case "[]driver.Replication":
238 name = "[]*Replication"
239 }
240 args[i] = fmt.Sprintf("ret%d %s", i, name)
241 }
242 return strings.Join(args, ", ")
243 }
244
// typeName returns the Go spelling used in generated code for t. It starts
// from t.String() and applies these rewrites:
//   - "interface {}" becomes "interface{}";
//   - driver.Rows, driver.Changes and driver.DBUpdates become *Rows,
//     *Changes and *Updates.
//
// Every other name is returned unchanged.
func typeName(t reflect.Type) string {
	var spelled string
	switch original := t.String(); original {
	case "interface {}":
		spelled = "interface{}"
	case "driver.Rows":
		spelled = "*Rows"
	case "driver.Changes":
		spelled = "*Changes"
	case "driver.DBUpdates":
		spelled = "*Updates"
	default:
		spelled = original
	}
	return spelled
}
259
260 func (m *method) SetExpectations() string {
261 var args []string
262 if m.DBMethod {
263 args = append(args, "commonExpectation: commonExpectation{db: db},\n")
264 }
265 if m.Name == "DB" {
266 args = append(args, "ret0: &DB{},\n")
267 }
268 for i, ret := range m.Returns {
269 var zero string
270 switch ret.String() {
271 case "*kivik.Rows":
272 zero = "&Rows{}"
273 case "*kivik.QueryPlan":
274 zero = "&driver.QueryPlan{}"
275 case "*kivik.PurgeResult":
276 zero = "&driver.PurgeResult{}"
277 case "*kivik.DBUpdates":
278 zero = "&Updates{}"
279 }
280 if zero != "" {
281 args = append(args, fmt.Sprintf("ret%d: %s,\n", i, zero))
282 }
283 }
284 return strings.Join(args, "")
285 }
286
287 func (m *method) MetExpectations() string {
288 if len(m.Accepts) == 0 {
289 return ""
290 }
291 args := make([]string, 0, len(m.Accepts)+1)
292 args = append(args, fmt.Sprintf("\texp := ex.(*Expected%s)", m.Name))
293 var check string
294 for i, arg := range m.Accepts {
295 switch arg.String() {
296 case "string":
297 check = `exp.arg%[1]d != "" && exp.arg%[1]d != e.arg%[1]d`
298 case "int":
299 check = "exp.arg%[1]d != 0 && exp.arg%[1]d != e.arg%[1]d"
300 case "interface {}":
301 check = "exp.arg%[1]d != nil && !jsonMeets(exp.arg%[1]d, e.arg%[1]d)"
302 default:
303 check = "exp.arg%[1]d != nil && !reflect.DeepEqual(exp.arg%[1]d, e.arg%[1]d)"
304 }
305 args = append(args, fmt.Sprintf("if "+check+" {\n\t\treturn false\n\t}", i))
306 }
307 return strings.Join(args, "\n")
308 }
309
310 func (m *method) MethodArgs() string {
311 str := make([]string, 0, len(m.Accepts)+1)
312 def := make([]string, 0, len(m.Accepts)+1)
313 const maxVarLen = 3
314 vars := make([]string, 0, maxVarLen)
315 var args, mid []string
316 prefix := ""
317 if m.DBMethod {
318 prefix = "DB(%s)."
319 args = append(args, "e.dbo().name")
320 }
321 if m.AcceptsContext {
322 vars = append(vars, "ctx")
323 }
324 var lines []string
325 for i, acc := range m.Accepts {
326 str = append(str, fmt.Sprintf("arg%d", i))
327 def = append(def, `"?"`)
328 vars = append(vars, "%s")
329 switch acc.String() {
330 case "string":
331 mid = append(mid, fmt.Sprintf(` if e.arg%[1]d != "" { arg%[1]d = fmt.Sprintf("%%q", e.arg%[1]d)}`, i))
332 case "int":
333 mid = append(mid, fmt.Sprintf(` if e.arg%[1]d != 0 { arg%[1]d = fmt.Sprintf("%%q", e.arg%[1]d)}`, i))
334 default:
335 mid = append(mid, fmt.Sprintf(` if e.arg%[1]d != nil { arg%[1]d = fmt.Sprintf("%%v", e.arg%[1]d) }`, i))
336 }
337 }
338 if m.AcceptsOptions {
339 str = append(str, "options")
340 def = append(def, `formatOptions(e.options)`)
341 vars = append(vars, "%s")
342 }
343 if len(str) > 0 {
344 lines = append(lines, fmt.Sprintf("\t%s := %s", strings.Join(str, ", "), strings.Join(def, ", ")))
345 }
346 lines = append(lines, mid...)
347 lines = append(lines, fmt.Sprintf("\treturn fmt.Sprintf(\"%s%s(%s)\", %s)", prefix, m.Name, strings.Join(vars, ", "), strings.Join(append(args, str...), ", ")))
348 return strings.Join(lines, "\n")
349 }
350
351
352 func (m *method) CallbackTypes() string {
353 const extraCount = 2
354 inputs := make([]string, 0, len(m.Accepts)+extraCount)
355 if m.AcceptsContext {
356 inputs = append(inputs, "context.Context")
357 }
358 for _, arg := range m.Accepts {
359 inputs = append(inputs, typeName(arg))
360 }
361 if m.AcceptsOptions {
362 inputs = append(inputs, "driver.Options")
363 }
364 return strings.Join(inputs, ", ")
365 }
366
367
368 func (m *method) CallbackArgs() string {
369 const extraCount = 2
370 args := make([]string, 0, len(m.Accepts)+extraCount)
371 if m.AcceptsContext {
372 args = append(args, "ctx")
373 }
374 for i := range m.Accepts {
375 args = append(args, fmt.Sprintf("arg%d", i))
376 }
377 if m.AcceptsOptions {
378 args = append(args, "options")
379 }
380 return strings.Join(args, ", ")
381 }
382
383 func (m *method) CallbackReturns() string {
384 args := make([]string, 0, len(m.Returns)+1)
385 for _, ret := range m.Returns {
386 args = append(args, ret.String())
387 }
388 if m.ReturnsError {
389 args = append(args, "error")
390 }
391 if len(args) > 1 {
392 return "(" + strings.Join(args, ", ") + ")"
393 }
394 return strings.Join(args, ", ")
395 }
396
View as plain text