...

Source file src/github.com/golang/mock/mockgen/model/model.go

Documentation: github.com/golang/mock/mockgen/model

     1  // Copyright 2012 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package model contains the data model necessary for generating mock implementations.
    16  package model
    17  
    18  import (
    19  	"encoding/gob"
    20  	"fmt"
    21  	"io"
    22  	"reflect"
    23  	"strings"
    24  )
    25  
// pkgPath is the importable path for package model. init uses it to register
// PredeclaredType with gob under a fixed name that does not depend on where
// (e.g. under a vendor directory) the package was compiled from.
const pkgPath = "github.com/golang/mock/mockgen/model"
    28  
// Package is a Go package. It may be a subset: only the interfaces listed in
// Interfaces are described.
//
// Name is the package name written by Print. Interfaces are the described
// interfaces; Imports collects the import paths they reference. PkgPath and
// DotImports are not read anywhere in this file. Presumably they hold the
// package's own import path and the paths it dot-imports.
// NOTE(review): confirm against callers.
type Package struct {
	Name       string
	PkgPath    string
	Interfaces []*Interface
	DotImports []string
}
    36  
    37  // Print writes the package name and its exported interfaces.
    38  func (pkg *Package) Print(w io.Writer) {
    39  	_, _ = fmt.Fprintf(w, "package %s\n", pkg.Name)
    40  	for _, intf := range pkg.Interfaces {
    41  		intf.Print(w)
    42  	}
    43  }
    44  
    45  // Imports returns the imports needed by the Package as a set of import paths.
    46  func (pkg *Package) Imports() map[string]bool {
    47  	im := make(map[string]bool)
    48  	for _, intf := range pkg.Interfaces {
    49  		intf.addImports(im)
    50  		for _, tp := range intf.TypeParams {
    51  			tp.Type.addImports(im)
    52  		}
    53  	}
    54  	return im
    55  }
    56  
// Interface is a Go interface.
//
// Name is written by Print. Methods lists the interface's methods; AddMethod
// keeps them unique by name. TypeParams lists the interface's type
// parameters; Package.Imports includes the imports their types need.
type Interface struct {
	Name       string
	Methods    []*Method
	TypeParams []*Parameter
}
    63  
    64  // Print writes the interface name and its methods.
    65  func (intf *Interface) Print(w io.Writer) {
    66  	_, _ = fmt.Fprintf(w, "interface %s\n", intf.Name)
    67  	for _, m := range intf.Methods {
    68  		m.Print(w)
    69  	}
    70  }
    71  
    72  func (intf *Interface) addImports(im map[string]bool) {
    73  	for _, m := range intf.Methods {
    74  		m.addImports(im)
    75  	}
    76  }
    77  
    78  // AddMethod adds a new method, de-duplicating by method name.
    79  func (intf *Interface) AddMethod(m *Method) {
    80  	for _, me := range intf.Methods {
    81  		if me.Name == m.Name {
    82  			return
    83  		}
    84  	}
    85  	intf.Methods = append(intf.Methods, m)
    86  }
    87  
// Method is a single method of an interface.
//
// In holds the fixed parameters and Out the results. Variadic, when non-nil,
// is the final ...T parameter, described by its element type T.
type Method struct {
	Name     string
	In, Out  []*Parameter
	Variadic *Parameter // may be nil
}
    94  
    95  // Print writes the method name and its signature.
    96  func (m *Method) Print(w io.Writer) {
    97  	_, _ = fmt.Fprintf(w, "  - method %s\n", m.Name)
    98  	if len(m.In) > 0 {
    99  		_, _ = fmt.Fprintf(w, "    in:\n")
   100  		for _, p := range m.In {
   101  			p.Print(w)
   102  		}
   103  	}
   104  	if m.Variadic != nil {
   105  		_, _ = fmt.Fprintf(w, "    ...:\n")
   106  		m.Variadic.Print(w)
   107  	}
   108  	if len(m.Out) > 0 {
   109  		_, _ = fmt.Fprintf(w, "    out:\n")
   110  		for _, p := range m.Out {
   111  			p.Print(w)
   112  		}
   113  	}
   114  }
   115  
   116  func (m *Method) addImports(im map[string]bool) {
   117  	for _, p := range m.In {
   118  		p.Type.addImports(im)
   119  	}
   120  	if m.Variadic != nil {
   121  		m.Variadic.Type.addImports(im)
   122  	}
   123  	for _, p := range m.Out {
   124  		p.Type.addImports(im)
   125  	}
   126  }
   127  
// Parameter is an argument or return parameter of a method.
//
// Name may be empty; parameters built from reflection by parameterFromType
// never have one. Type is the parameter's type.
type Parameter struct {
	Name string // may be empty
	Type Type
}
   133  
   134  // Print writes a method parameter.
   135  func (p *Parameter) Print(w io.Writer) {
   136  	n := p.Name
   137  	if n == "" {
   138  		n = `""`
   139  	}
   140  	_, _ = fmt.Fprintf(w, "    - %v: %v\n", n, p.Type.String(nil, ""))
   141  }
   142  
// Type is a Go type.
//
// String renders the type as Go source. pm maps import paths to the package
// names used to qualify named types. A named type whose package equals
// pkgOverride is written unqualified. addImports adds every import path the
// type refers to as a key (set to true) in im.
type Type interface {
	String(pm map[string]string, pkgOverride string) string
	addImports(im map[string]bool)
}
   148  
// init registers the concrete Type implementations with encoding/gob. gob can
// only encode or decode a value stored in an interface-typed field (such as
// Parameter.Type) if its concrete type has been registered.
func init() {
	gob.Register(&ArrayType{})
	gob.Register(&ChanType{})
	gob.Register(&FuncType{})
	gob.Register(&MapType{})
	gob.Register(&NamedType{})
	gob.Register(&PointerType{})

	// Call gob.RegisterName so that the gob encoder and decoder agree on the
	// name registered for PredeclaredType.
	//
	// For a non-pointer type, gob.Register derives the name to register from
	// the package path (rt.PkgPath()). If the project has a vendor directory,
	// PkgPath may return a path like this, so the two sides could disagree:
	//     ../../../vendor/github.com/golang/mock/mockgen/model
	gob.RegisterName(pkgPath+".PredeclaredType", PredeclaredType(""))
}
   166  
// ArrayType is an array or slice type. Type is the element type.
type ArrayType struct {
	Len  int // -1 for slices, >= 0 for arrays
	Type Type
}
   172  
   173  func (at *ArrayType) String(pm map[string]string, pkgOverride string) string {
   174  	s := "[]"
   175  	if at.Len > -1 {
   176  		s = fmt.Sprintf("[%d]", at.Len)
   177  	}
   178  	return s + at.Type.String(pm, pkgOverride)
   179  }
   180  
   181  func (at *ArrayType) addImports(im map[string]bool) { at.Type.addImports(im) }
   182  
// ChanType is a channel type. Type is the element type. Dir is RecvDir,
// SendDir, or 0 for a bidirectional channel.
type ChanType struct {
	Dir  ChanDir // 0, 1 or 2
	Type Type
}
   188  
   189  func (ct *ChanType) String(pm map[string]string, pkgOverride string) string {
   190  	s := ct.Type.String(pm, pkgOverride)
   191  	if ct.Dir == RecvDir {
   192  		return "<-chan " + s
   193  	}
   194  	if ct.Dir == SendDir {
   195  		return "chan<- " + s
   196  	}
   197  	return "chan " + s
   198  }
   199  
   200  func (ct *ChanType) addImports(im map[string]bool) { ct.Type.addImports(im) }
   201  
// ChanDir is a channel direction. The zero value means the channel is
// bidirectional; typeFromType leaves Dir at 0 for reflect.BothDir.
type ChanDir int

// Constants for channel directions.
const (
	RecvDir ChanDir = 1 // receive-only: <-chan T
	SendDir ChanDir = 2 // send-only: chan<- T
)
   210  
// FuncType is a function type.
//
// In holds the fixed parameters and Out the results. Variadic, when non-nil,
// is the final ...T parameter, described by its element type T.
type FuncType struct {
	In, Out  []*Parameter
	Variadic *Parameter // may be nil
}
   216  
   217  func (ft *FuncType) String(pm map[string]string, pkgOverride string) string {
   218  	args := make([]string, len(ft.In))
   219  	for i, p := range ft.In {
   220  		args[i] = p.Type.String(pm, pkgOverride)
   221  	}
   222  	if ft.Variadic != nil {
   223  		args = append(args, "..."+ft.Variadic.Type.String(pm, pkgOverride))
   224  	}
   225  	rets := make([]string, len(ft.Out))
   226  	for i, p := range ft.Out {
   227  		rets[i] = p.Type.String(pm, pkgOverride)
   228  	}
   229  	retString := strings.Join(rets, ", ")
   230  	if nOut := len(ft.Out); nOut == 1 {
   231  		retString = " " + retString
   232  	} else if nOut > 1 {
   233  		retString = " (" + retString + ")"
   234  	}
   235  	return "func(" + strings.Join(args, ", ") + ")" + retString
   236  }
   237  
   238  func (ft *FuncType) addImports(im map[string]bool) {
   239  	for _, p := range ft.In {
   240  		p.Type.addImports(im)
   241  	}
   242  	if ft.Variadic != nil {
   243  		ft.Variadic.Type.addImports(im)
   244  	}
   245  	for _, p := range ft.Out {
   246  		p.Type.addImports(im)
   247  	}
   248  }
   249  
// MapType is a map type. Key is the key type and Value the element type.
type MapType struct {
	Key, Value Type
}
   254  
   255  func (mt *MapType) String(pm map[string]string, pkgOverride string) string {
   256  	return "map[" + mt.Key.String(pm, pkgOverride) + "]" + mt.Value.String(pm, pkgOverride)
   257  }
   258  
   259  func (mt *MapType) addImports(im map[string]bool) {
   260  	mt.Key.addImports(im)
   261  	mt.Value.addImports(im)
   262  }
   263  
// NamedType is a named type declared in a package. typeFromType builds one for
// any reflected type with a non-empty PkgPath, exported or not.
//
// Package is the declaring package's import path. It may be empty, in which
// case addImports records nothing for it. Type is the bare type name.
// TypeParams holds the type arguments rendered after the name, and may be nil.
type NamedType struct {
	Package    string // may be empty
	Type       string
	TypeParams *TypeParametersType
}
   270  
   271  func (nt *NamedType) String(pm map[string]string, pkgOverride string) string {
   272  	if pkgOverride == nt.Package {
   273  		return nt.Type + nt.TypeParams.String(pm, pkgOverride)
   274  	}
   275  	prefix := pm[nt.Package]
   276  	if prefix != "" {
   277  		return prefix + "." + nt.Type + nt.TypeParams.String(pm, pkgOverride)
   278  	}
   279  
   280  	return nt.Type + nt.TypeParams.String(pm, pkgOverride)
   281  }
   282  
   283  func (nt *NamedType) addImports(im map[string]bool) {
   284  	if nt.Package != "" {
   285  		im[nt.Package] = true
   286  	}
   287  	nt.TypeParams.addImports(im)
   288  }
   289  
// PointerType is a pointer to another type. Type is the pointed-to type.
type PointerType struct {
	Type Type
}
   294  
   295  func (pt *PointerType) String(pm map[string]string, pkgOverride string) string {
   296  	return "*" + pt.Type.String(pm, pkgOverride)
   297  }
   298  func (pt *PointerType) addImports(im map[string]bool) { pt.Type.addImports(im) }
   299  
   300  // PredeclaredType is a predeclared type such as "int".
   301  type PredeclaredType string
   302  
   303  func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) }
   304  func (pt PredeclaredType) addImports(map[string]bool)              {}
   305  
// TypeParametersType contains the type parameters (type arguments) of a
// NamedType, e.g. the int in Foo[int]. NamedType.String renders them in
// brackets after the type name.
type TypeParametersType struct {
	TypeParameters []Type
}
   310  
   311  func (tp *TypeParametersType) String(pm map[string]string, pkgOverride string) string {
   312  	if tp == nil || len(tp.TypeParameters) == 0 {
   313  		return ""
   314  	}
   315  	var sb strings.Builder
   316  	sb.WriteString("[")
   317  	for i, v := range tp.TypeParameters {
   318  		if i != 0 {
   319  			sb.WriteString(", ")
   320  		}
   321  		sb.WriteString(v.String(pm, pkgOverride))
   322  	}
   323  	sb.WriteString("]")
   324  	return sb.String()
   325  }
   326  
   327  func (tp *TypeParametersType) addImports(im map[string]bool) {
   328  	if tp == nil {
   329  		return
   330  	}
   331  	for _, v := range tp.TypeParameters {
   332  		v.addImports(im)
   333  	}
   334  }
   335  
   336  // The following code is intended to be called by the program generated by ../reflect.go.
   337  
   338  // InterfaceFromInterfaceType returns a pointer to an interface for the
   339  // given reflection interface type.
   340  func InterfaceFromInterfaceType(it reflect.Type) (*Interface, error) {
   341  	if it.Kind() != reflect.Interface {
   342  		return nil, fmt.Errorf("%v is not an interface", it)
   343  	}
   344  	intf := &Interface{}
   345  
   346  	for i := 0; i < it.NumMethod(); i++ {
   347  		mt := it.Method(i)
   348  		// TODO: need to skip unexported methods? or just raise an error?
   349  		m := &Method{
   350  			Name: mt.Name,
   351  		}
   352  
   353  		var err error
   354  		m.In, m.Variadic, m.Out, err = funcArgsFromType(mt.Type)
   355  		if err != nil {
   356  			return nil, err
   357  		}
   358  
   359  		intf.AddMethod(m)
   360  	}
   361  
   362  	return intf, nil
   363  }
   364  
   365  // t's Kind must be a reflect.Func.
   366  func funcArgsFromType(t reflect.Type) (in []*Parameter, variadic *Parameter, out []*Parameter, err error) {
   367  	nin := t.NumIn()
   368  	if t.IsVariadic() {
   369  		nin--
   370  	}
   371  	var p *Parameter
   372  	for i := 0; i < nin; i++ {
   373  		p, err = parameterFromType(t.In(i))
   374  		if err != nil {
   375  			return
   376  		}
   377  		in = append(in, p)
   378  	}
   379  	if t.IsVariadic() {
   380  		p, err = parameterFromType(t.In(nin).Elem())
   381  		if err != nil {
   382  			return
   383  		}
   384  		variadic = p
   385  	}
   386  	for i := 0; i < t.NumOut(); i++ {
   387  		p, err = parameterFromType(t.Out(i))
   388  		if err != nil {
   389  			return
   390  		}
   391  		out = append(out, p)
   392  	}
   393  	return
   394  }
   395  
   396  func parameterFromType(t reflect.Type) (*Parameter, error) {
   397  	tt, err := typeFromType(t)
   398  	if err != nil {
   399  		return nil, err
   400  	}
   401  	return &Parameter{Type: tt}, nil
   402  }
   403  
// errorType is the reflect.Type of the predeclared error interface;
// typeFromType compares against it to produce PredeclaredType("error").
var errorType = reflect.TypeOf((*error)(nil)).Elem()

// byteType is the reflect.Type of byte (which is identical to uint8);
// typeFromType compares against it to spell that type as "byte".
var byteType = reflect.TypeOf(byte(0))
   407  
// typeFromType converts a reflect.Type into the model's Type representation.
//
// Types with a non-empty PkgPath become a NamedType qualified by the
// vendor-stripped import path; TypeParams is left nil for them. Unnamed array,
// slice, chan, func, map and pointer types are converted recursively. Basic
// kinds become a PredeclaredType spelled as the reflect.Kind name, except
// that byte (identical to uint8) is always spelled "byte". The empty
// interface, error, and the empty struct also become PredeclaredType. Any
// other type yields an error. That includes a non-empty unnamed interface or
// struct, or an unsafe.Pointer, possibly nested inside a composite type.
func typeFromType(t reflect.Type) (Type, error) {
	// Hack workaround for https://golang.org/issue/3853.
	// This explicit check should not be necessary.
	if t == byteType {
		return PredeclaredType("byte"), nil
	}

	// Named types declared in a package are referenced by import path + name.
	if imp := t.PkgPath(); imp != "" {
		return &NamedType{
			Package: impPath(imp),
			Type:    t.Name(),
		}, nil
	}

	// only unnamed or predeclared types after here

	// Lots of types have element types. Let's do the parsing and error checking for all of them.
	// An unconvertible element type is therefore reported before any
	// kind-specific handling below (including the map key).
	var elemType Type
	switch t.Kind() {
	case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
		var err error
		elemType, err = typeFromType(t.Elem())
		if err != nil {
			return nil, err
		}
	}

	switch t.Kind() {
	case reflect.Array:
		return &ArrayType{
			Len:  t.Len(),
			Type: elemType,
		}, nil
	case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String:
		// The Kind's String() is the Go spelling of the basic type.
		return PredeclaredType(t.Kind().String()), nil
	case reflect.Chan:
		// reflect.BothDir (and anything else) leaves dir at 0: bidirectional.
		var dir ChanDir
		switch t.ChanDir() {
		case reflect.RecvDir:
			dir = RecvDir
		case reflect.SendDir:
			dir = SendDir
		}
		return &ChanType{
			Dir:  dir,
			Type: elemType,
		}, nil
	case reflect.Func:
		in, variadic, out, err := funcArgsFromType(t)
		if err != nil {
			return nil, err
		}
		return &FuncType{
			In:       in,
			Out:      out,
			Variadic: variadic,
		}, nil
	case reflect.Interface:
		// Two special interfaces.
		if t.NumMethod() == 0 {
			return PredeclaredType("interface{}"), nil
		}
		if t == errorType {
			return PredeclaredType("error"), nil
		}
		// Any other unnamed interface falls through to the error below.
	case reflect.Map:
		kt, err := typeFromType(t.Key())
		if err != nil {
			return nil, err
		}
		return &MapType{
			Key:   kt,
			Value: elemType,
		}, nil
	case reflect.Ptr:
		return &PointerType{
			Type: elemType,
		}, nil
	case reflect.Slice:
		// Slices share ArrayType; Len -1 marks "no length".
		return &ArrayType{
			Len:  -1,
			Type: elemType,
		}, nil
	case reflect.Struct:
		if t.NumField() == 0 {
			return PredeclaredType("struct{}"), nil
		}
		// Non-empty unnamed structs fall through to the error below.
	}

	// TODO: Struct, UnsafePointer
	return nil, fmt.Errorf("can't yet turn %v (%v) into a model.Type", t, t.Kind())
}
   502  
   503  // impPath sanitizes the package path returned by `PkgPath` method of a reflect Type so that
   504  // it is importable. PkgPath might return a path that includes "vendor". These paths do not
   505  // compile, so we need to remove everything up to and including "/vendor/".
   506  // See https://github.com/golang/go/issues/12019.
   507  func impPath(imp string) string {
   508  	if strings.HasPrefix(imp, "vendor/") {
   509  		imp = "/" + imp
   510  	}
   511  	if i := strings.LastIndex(imp, "/vendor/"); i != -1 {
   512  		imp = imp[i+len("/vendor/"):]
   513  	}
   514  	return imp
   515  }
   516  
// ErrorInterface represents the built-in error interface: a single method,
// Error, taking no arguments and returning an unnamed string result.
var ErrorInterface = Interface{
	Name: "error",
	Methods: []*Method{
		{
			Name: "Error",
			Out: []*Parameter{
				{
					Name: "",
					Type: PredeclaredType("string"),
				},
			},
		},
	},
}
   532  

View as plain text