...

Source file src/github.com/go-kivik/kivik/v4/mockdb/gen/render.go

Documentation: github.com/go-kivik/kivik/v4/mockdb/gen

     1  // Licensed under the Apache License, Version 2.0 (the "License"); you may not
     2  // use this file except in compliance with the License. You may obtain a copy of
     3  // the License at
     4  //
     5  //  http://www.apache.org/licenses/LICENSE-2.0
     6  //
     7  // Unless required by applicable law or agreed to in writing, software
     8  // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
     9  // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    10  // License for the specific language governing permissions and limitations under
    11  // the License.
    12  
    13  package main
    14  
    15  import (
    16  	"bytes"
    17  	"fmt"
    18  	"os"
    19  	"reflect"
    20  	"strings"
    21  	"text/template"
    22  )
    23  
    24  var tmpl *template.Template
    25  
    26  func initTemplates(root string) {
    27  	var err error
    28  	tmpl, err = template.ParseGlob(root + "/*")
    29  	if err != nil {
    30  		panic(err)
    31  	}
    32  }
    33  
    34  func renderExpectationsGo(filename string, methods []*method) error {
    35  	file, err := os.Create(filename)
    36  	if err != nil {
    37  		return err
    38  	}
    39  	return tmpl.ExecuteTemplate(file, "expectations.go.tmpl", methods)
    40  }
    41  
    42  func renderClientGo(filename string, methods []*method) error {
    43  	file, err := os.Create(filename)
    44  	if err != nil {
    45  		return err
    46  	}
    47  	return tmpl.ExecuteTemplate(file, "client.go.tmpl", methods)
    48  }
    49  
    50  func renderMockGo(filename string, methods []*method) error {
    51  	file, err := os.Create(filename)
    52  	if err != nil {
    53  		return err
    54  	}
    55  	return tmpl.ExecuteTemplate(file, "mock.go.tmpl", methods)
    56  }
    57  
    58  func renderDriverMethod(m *method) (string, error) {
    59  	buf := &bytes.Buffer{}
    60  	err := tmpl.ExecuteTemplate(buf, "drivermethod.tmpl", m)
    61  	return buf.String(), err
    62  }
    63  
    64  func renderExpectedType(m *method) (string, error) {
    65  	buf := &bytes.Buffer{}
    66  	err := tmpl.ExecuteTemplate(buf, "expectedtype.tmpl", m)
    67  	return buf.String(), err
    68  }
    69  
    70  func (m *method) DriverArgs() string {
    71  	const extraCount = 2
    72  	args := make([]string, 0, len(m.Accepts)+extraCount)
    73  	if m.AcceptsContext {
    74  		args = append(args, "ctx context.Context")
    75  	}
    76  	for i, arg := range m.Accepts {
    77  		args = append(args, fmt.Sprintf("arg%d %s", i, typeName(arg)))
    78  	}
    79  	if m.AcceptsOptions {
    80  		args = append(args, "options driver.Options")
    81  	}
    82  	return strings.Join(args, ", ")
    83  }
    84  
    85  func (m *method) ReturnArgs() string {
    86  	args := make([]string, 0, len(m.Returns)+1)
    87  	for _, arg := range m.Returns {
    88  		args = append(args, arg.String())
    89  	}
    90  	if m.ReturnsError {
    91  		args = append(args, "error")
    92  	}
    93  	if len(args) > 1 {
    94  		return `(` + strings.Join(args, ", ") + `)`
    95  	}
    96  	return args[0]
    97  }
    98  
    99  func (m *method) VariableDefinitions() string {
   100  	result := make([]string, 0, len(m.Accepts)+len(m.Returns))
   101  	for i, arg := range m.Accepts {
   102  		result = append(result, fmt.Sprintf("\targ%d %s\n", i, typeName(arg)))
   103  	}
   104  	for i, ret := range m.Returns {
   105  		name := typeName(ret)
   106  		switch name {
   107  		case "driver.DB": // nolint: goconst
   108  			name = "*DB"
   109  		case "driver.Replication": // nolint: goconst
   110  			name = "*Replication"
   111  		case "[]driver.Replication": // nolint: goconst
   112  			name = "[]*Replication"
   113  		}
   114  		result = append(result, fmt.Sprintf("\tret%d %s\n", i, name))
   115  	}
   116  	return strings.Join(result, "")
   117  }
   118  
   119  func (m *method) inputVars() []string {
   120  	args := make([]string, 0, len(m.Accepts)+1)
   121  	for i := range m.Accepts {
   122  		args = append(args, fmt.Sprintf("arg%d", i))
   123  	}
   124  	if m.AcceptsOptions {
   125  		args = append(args, "options")
   126  	}
   127  	return args
   128  }
   129  
   130  func (m *method) ExpectedVariables() string {
   131  	args := []string{}
   132  	if m.DBMethod {
   133  		args = append(args, "db")
   134  	}
   135  	args = append(args, m.inputVars()...)
   136  	return alignVars(0, args)
   137  }
   138  
   139  func (m *method) InputVariables() string {
   140  	result := make([]string, len(m.Accepts)+1)
   141  	var common []string
   142  	if m.DBMethod {
   143  		common = append(common, "\t\t\tdb: db.DB,\n")
   144  	}
   145  	for i := range m.Accepts {
   146  		result = append(result, fmt.Sprintf("\t\targ%d: arg%d,\n", i, i))
   147  	}
   148  	if m.AcceptsOptions {
   149  		common = append(common, "\t\t\toptions: options,\n")
   150  	}
   151  	if len(common) > 0 {
   152  		result = append(result, fmt.Sprintf("\t\tcommonExpectation: commonExpectation{\n%s\t\t},\n",
   153  			strings.Join(common, "")))
   154  	}
   155  	return strings.Join(result, "")
   156  }
   157  
   158  func (m *method) Variables(indent int) string {
   159  	args := m.inputVars()
   160  	for i := range m.Returns {
   161  		args = append(args, fmt.Sprintf("ret%d", i))
   162  	}
   163  	return alignVars(indent, args)
   164  }
   165  
   166  func alignVars(indent int, args []string) string {
   167  	var maxLen int
   168  	for _, arg := range args {
   169  		if l := len(arg); l > maxLen {
   170  			maxLen = l
   171  		}
   172  	}
   173  	final := make([]string, len(args))
   174  	for i, arg := range args {
   175  		final[i] = fmt.Sprintf("%s%*s %s,", strings.Repeat("\t", indent), -(maxLen + 1), arg+":", arg)
   176  	}
   177  	return strings.Join(final, "\n")
   178  }
   179  
   180  func (m *method) ZeroReturns() string {
   181  	args := make([]string, 0, len(m.Returns))
   182  	for _, arg := range m.Returns {
   183  		args = append(args, zeroValue(arg))
   184  	}
   185  	args = append(args, "err")
   186  	return strings.Join(args, ", ")
   187  }
   188  
   189  func zeroValue(t reflect.Type) string {
   190  	z := fmt.Sprintf("%#v", reflect.Zero(t).Interface())
   191  	if strings.HasSuffix(z, "(nil)") {
   192  		return "nil"
   193  	}
   194  	if z == "<nil>" {
   195  		return "nil"
   196  	}
   197  	return z
   198  }
   199  
   200  func (m *method) ExpectedReturns() string {
   201  	args := make([]string, 0, len(m.Returns))
   202  	for i, arg := range m.Returns {
   203  		switch arg.String() {
   204  		case "driver.Rows":
   205  			args = append(args, fmt.Sprintf("&driverRows{Context: ctx, Rows: coalesceRows(expected.ret%d)}", i))
   206  		case "driver.Changes":
   207  			args = append(args, fmt.Sprintf("&driverChanges{Context: ctx, Changes: coalesceChanges(expected.ret%d)}", i))
   208  		case "driver.DB":
   209  			args = append(args, fmt.Sprintf("&driverDB{DB: expected.ret%d}", i))
   210  		case "driver.DBUpdates":
   211  			args = append(args, fmt.Sprintf("&driverDBUpdates{Context:ctx, Updates: coalesceDBUpdates(expected.ret%d)}", i))
   212  		case "driver.Replication":
   213  			args = append(args, fmt.Sprintf("&driverReplication{Replication: expected.ret%d}", i))
   214  		case "[]driver.Replication":
   215  			args = append(args, fmt.Sprintf("driverReplications(expected.ret%d)", i))
   216  		default:
   217  			args = append(args, fmt.Sprintf("expected.ret%d", i))
   218  		}
   219  	}
   220  	if m.AcceptsContext {
   221  		args = append(args, "expected.wait(ctx)")
   222  	} else {
   223  		args = append(args, "expected.err")
   224  	}
   225  	return strings.Join(args, ", ")
   226  }
   227  
   228  func (m *method) ReturnTypes() string {
   229  	args := make([]string, len(m.Returns))
   230  	for i, ret := range m.Returns {
   231  		name := typeName(ret)
   232  		switch name {
   233  		case "driver.DB":
   234  			name = "*DB"
   235  		case "driver.Replication":
   236  			name = "*Replication"
   237  		case "[]driver.Replication":
   238  			name = "[]*Replication"
   239  		}
   240  		args[i] = fmt.Sprintf("ret%d %s", i, name)
   241  	}
   242  	return strings.Join(args, ", ")
   243  }
   244  
   245  func typeName(t reflect.Type) string {
   246  	name := t.String()
   247  	switch name {
   248  	case "interface {}":
   249  		return "interface{}"
   250  	case "driver.Rows":
   251  		return "*Rows"
   252  	case "driver.Changes":
   253  		return "*Changes"
   254  	case "driver.DBUpdates":
   255  		return "*Updates"
   256  	}
   257  	return name
   258  }
   259  
   260  func (m *method) SetExpectations() string {
   261  	var args []string
   262  	if m.DBMethod {
   263  		args = append(args, "commonExpectation: commonExpectation{db: db},\n")
   264  	}
   265  	if m.Name == "DB" {
   266  		args = append(args, "ret0: &DB{},\n")
   267  	}
   268  	for i, ret := range m.Returns {
   269  		var zero string
   270  		switch ret.String() {
   271  		case "*kivik.Rows":
   272  			zero = "&Rows{}"
   273  		case "*kivik.QueryPlan":
   274  			zero = "&driver.QueryPlan{}"
   275  		case "*kivik.PurgeResult":
   276  			zero = "&driver.PurgeResult{}"
   277  		case "*kivik.DBUpdates":
   278  			zero = "&Updates{}"
   279  		}
   280  		if zero != "" {
   281  			args = append(args, fmt.Sprintf("ret%d: %s,\n", i, zero))
   282  		}
   283  	}
   284  	return strings.Join(args, "")
   285  }
   286  
   287  func (m *method) MetExpectations() string {
   288  	if len(m.Accepts) == 0 {
   289  		return ""
   290  	}
   291  	args := make([]string, 0, len(m.Accepts)+1)
   292  	args = append(args, fmt.Sprintf("\texp := ex.(*Expected%s)", m.Name))
   293  	var check string
   294  	for i, arg := range m.Accepts {
   295  		switch arg.String() {
   296  		case "string":
   297  			check = `exp.arg%[1]d != "" && exp.arg%[1]d != e.arg%[1]d`
   298  		case "int":
   299  			check = "exp.arg%[1]d != 0 && exp.arg%[1]d != e.arg%[1]d"
   300  		case "interface {}":
   301  			check = "exp.arg%[1]d != nil && !jsonMeets(exp.arg%[1]d, e.arg%[1]d)"
   302  		default:
   303  			check = "exp.arg%[1]d != nil && !reflect.DeepEqual(exp.arg%[1]d, e.arg%[1]d)"
   304  		}
   305  		args = append(args, fmt.Sprintf("if "+check+" {\n\t\treturn false\n\t}", i))
   306  	}
   307  	return strings.Join(args, "\n")
   308  }
   309  
   310  func (m *method) MethodArgs() string {
   311  	str := make([]string, 0, len(m.Accepts)+1)
   312  	def := make([]string, 0, len(m.Accepts)+1)
   313  	const maxVarLen = 3
   314  	vars := make([]string, 0, maxVarLen)
   315  	var args, mid []string
   316  	prefix := ""
   317  	if m.DBMethod {
   318  		prefix = "DB(%s)."
   319  		args = append(args, "e.dbo().name")
   320  	}
   321  	if m.AcceptsContext {
   322  		vars = append(vars, "ctx")
   323  	}
   324  	var lines []string
   325  	for i, acc := range m.Accepts {
   326  		str = append(str, fmt.Sprintf("arg%d", i))
   327  		def = append(def, `"?"`)
   328  		vars = append(vars, "%s")
   329  		switch acc.String() {
   330  		case "string":
   331  			mid = append(mid, fmt.Sprintf(`	if e.arg%[1]d != "" { arg%[1]d = fmt.Sprintf("%%q", e.arg%[1]d)}`, i))
   332  		case "int":
   333  			mid = append(mid, fmt.Sprintf(`	if e.arg%[1]d != 0 { arg%[1]d = fmt.Sprintf("%%q", e.arg%[1]d)}`, i))
   334  		default:
   335  			mid = append(mid, fmt.Sprintf(`	if e.arg%[1]d != nil { arg%[1]d = fmt.Sprintf("%%v", e.arg%[1]d) }`, i))
   336  		}
   337  	}
   338  	if m.AcceptsOptions {
   339  		str = append(str, "options")
   340  		def = append(def, `formatOptions(e.options)`)
   341  		vars = append(vars, "%s")
   342  	}
   343  	if len(str) > 0 {
   344  		lines = append(lines, fmt.Sprintf("\t%s := %s", strings.Join(str, ", "), strings.Join(def, ", ")))
   345  	}
   346  	lines = append(lines, mid...)
   347  	lines = append(lines, fmt.Sprintf("\treturn fmt.Sprintf(\"%s%s(%s)\", %s)", prefix, m.Name, strings.Join(vars, ", "), strings.Join(append(args, str...), ", ")))
   348  	return strings.Join(lines, "\n")
   349  }
   350  
   351  // CallbackType returns the type definition for a callback for this method.
   352  func (m *method) CallbackTypes() string {
   353  	const extraCount = 2
   354  	inputs := make([]string, 0, len(m.Accepts)+extraCount)
   355  	if m.AcceptsContext {
   356  		inputs = append(inputs, "context.Context")
   357  	}
   358  	for _, arg := range m.Accepts {
   359  		inputs = append(inputs, typeName(arg))
   360  	}
   361  	if m.AcceptsOptions {
   362  		inputs = append(inputs, "driver.Options")
   363  	}
   364  	return strings.Join(inputs, ", ")
   365  }
   366  
   367  // CallbackArgs returns the list of arguments to be passed to the callback
   368  func (m *method) CallbackArgs() string {
   369  	const extraCount = 2
   370  	args := make([]string, 0, len(m.Accepts)+extraCount)
   371  	if m.AcceptsContext {
   372  		args = append(args, "ctx")
   373  	}
   374  	for i := range m.Accepts {
   375  		args = append(args, fmt.Sprintf("arg%d", i))
   376  	}
   377  	if m.AcceptsOptions {
   378  		args = append(args, "options")
   379  	}
   380  	return strings.Join(args, ", ")
   381  }
   382  
   383  func (m *method) CallbackReturns() string {
   384  	args := make([]string, 0, len(m.Returns)+1)
   385  	for _, ret := range m.Returns {
   386  		args = append(args, ret.String())
   387  	}
   388  	if m.ReturnsError {
   389  		args = append(args, "error")
   390  	}
   391  	if len(args) > 1 {
   392  		return "(" + strings.Join(args, ", ") + ")"
   393  	}
   394  	return strings.Join(args, ", ")
   395  }
   396  

View as plain text