...

Source file src/github.com/gogo/protobuf/protoc-gen-gogo/generator/helper.go

Documentation: github.com/gogo/protobuf/protoc-gen-gogo/generator

     1  // Protocol Buffers for Go with Gadgets
     2  //
     3  // Copyright (c) 2013, The GoGo Authors. All rights reserved.
     4  // http://github.com/gogo/protobuf
     5  //
     6  // Redistribution and use in source and binary forms, with or without
     7  // modification, are permitted provided that the following conditions are
     8  // met:
     9  //
    10  //     * Redistributions of source code must retain the above copyright
    11  // notice, this list of conditions and the following disclaimer.
    12  //     * Redistributions in binary form must reproduce the above
    13  // copyright notice, this list of conditions and the following disclaimer
    14  // in the documentation and/or other materials provided with the
    15  // distribution.
    16  //
    17  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    18  // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    19  // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    20  // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    21  // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    22  // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    23  // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    24  // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    25  // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    26  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    27  // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    28  
    29  package generator
    30  
    31  import (
    32  	"bytes"
    33  	"go/parser"
    34  	"go/printer"
    35  	"go/token"
    36  	"path"
    37  	"strings"
    38  
    39  	"github.com/gogo/protobuf/gogoproto"
    40  	"github.com/gogo/protobuf/proto"
    41  	descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
    42  	plugin "github.com/gogo/protobuf/protoc-gen-gogo/plugin"
    43  )
    44  
    45  func (d *FileDescriptor) Messages() []*Descriptor {
    46  	return d.desc
    47  }
    48  
    49  func (d *FileDescriptor) Enums() []*EnumDescriptor {
    50  	return d.enum
    51  }
    52  
    53  func (d *Descriptor) IsGroup() bool {
    54  	return d.group
    55  }
    56  
    57  func (g *Generator) IsGroup(field *descriptor.FieldDescriptorProto) bool {
    58  	if d, ok := g.typeNameToObject[field.GetTypeName()].(*Descriptor); ok {
    59  		return d.IsGroup()
    60  	}
    61  	return false
    62  }
    63  
    64  func (g *Generator) TypeNameByObject(typeName string) Object {
    65  	o, ok := g.typeNameToObject[typeName]
    66  	if !ok {
    67  		g.Fail("can't find object with type", typeName)
    68  	}
    69  	return o
    70  }
    71  
    72  func (g *Generator) OneOfTypeName(message *Descriptor, field *descriptor.FieldDescriptorProto) string {
    73  	typeName := message.TypeName()
    74  	ccTypeName := CamelCaseSlice(typeName)
    75  	fieldName := g.GetOneOfFieldName(message, field)
    76  	tname := ccTypeName + "_" + fieldName
    77  	// It is possible for this to collide with a message or enum
    78  	// nested in this message. Check for collisions.
    79  	ok := true
    80  	for _, desc := range message.nested {
    81  		if strings.Join(desc.TypeName(), "_") == tname {
    82  			ok = false
    83  			break
    84  		}
    85  	}
    86  	for _, enum := range message.enums {
    87  		if strings.Join(enum.TypeName(), "_") == tname {
    88  			ok = false
    89  			break
    90  		}
    91  	}
    92  	if !ok {
    93  		tname += "_"
    94  	}
    95  	return tname
    96  }
    97  
    98  type PluginImports interface {
    99  	NewImport(pkg string) Single
   100  	GenerateImports(file *FileDescriptor)
   101  }
   102  
   103  type pluginImports struct {
   104  	generator *Generator
   105  	singles   []Single
   106  }
   107  
   108  func NewPluginImports(generator *Generator) *pluginImports {
   109  	return &pluginImports{generator, make([]Single, 0)}
   110  }
   111  
   112  func (this *pluginImports) NewImport(pkg string) Single {
   113  	imp := newImportedPackage(this.generator.ImportPrefix, pkg)
   114  	this.singles = append(this.singles, imp)
   115  	return imp
   116  }
   117  
   118  func (this *pluginImports) GenerateImports(file *FileDescriptor) {
   119  	for _, s := range this.singles {
   120  		if s.IsUsed() {
   121  			this.generator.PrintImport(GoPackageName(s.Name()), GoImportPath(s.Location()))
   122  		}
   123  	}
   124  }
   125  
   126  type Single interface {
   127  	Use() string
   128  	IsUsed() bool
   129  	Name() string
   130  	Location() string
   131  }
   132  
   133  type importedPackage struct {
   134  	used         bool
   135  	pkg          string
   136  	name         string
   137  	importPrefix string
   138  }
   139  
   140  func newImportedPackage(importPrefix string, pkg string) *importedPackage {
   141  	return &importedPackage{
   142  		pkg:          pkg,
   143  		importPrefix: importPrefix,
   144  	}
   145  }
   146  
   147  func (this *importedPackage) Use() string {
   148  	if !this.used {
   149  		this.name = string(cleanPackageName(this.pkg))
   150  		this.used = true
   151  	}
   152  	return this.name
   153  }
   154  
   155  func (this *importedPackage) IsUsed() bool {
   156  	return this.used
   157  }
   158  
   159  func (this *importedPackage) Name() string {
   160  	return this.name
   161  }
   162  
   163  func (this *importedPackage) Location() string {
   164  	return this.importPrefix + this.pkg
   165  }
   166  
   167  func (g *Generator) GetFieldName(message *Descriptor, field *descriptor.FieldDescriptorProto) string {
   168  	goTyp, _ := g.GoType(message, field)
   169  	fieldname := CamelCase(*field.Name)
   170  	if gogoproto.IsCustomName(field) {
   171  		fieldname = gogoproto.GetCustomName(field)
   172  	}
   173  	if gogoproto.IsEmbed(field) {
   174  		fieldname = EmbedFieldName(goTyp)
   175  	}
   176  	if field.OneofIndex != nil {
   177  		fieldname = message.OneofDecl[int(*field.OneofIndex)].GetName()
   178  		fieldname = CamelCase(fieldname)
   179  	}
   180  	for _, f := range methodNames {
   181  		if f == fieldname {
   182  			return fieldname + "_"
   183  		}
   184  	}
   185  	if !gogoproto.IsProtoSizer(message.file.FileDescriptorProto, message.DescriptorProto) {
   186  		if fieldname == "Size" {
   187  			return fieldname + "_"
   188  		}
   189  	}
   190  	return fieldname
   191  }
   192  
   193  func (g *Generator) GetOneOfFieldName(message *Descriptor, field *descriptor.FieldDescriptorProto) string {
   194  	goTyp, _ := g.GoType(message, field)
   195  	fieldname := CamelCase(*field.Name)
   196  	if gogoproto.IsCustomName(field) {
   197  		fieldname = gogoproto.GetCustomName(field)
   198  	}
   199  	if gogoproto.IsEmbed(field) {
   200  		fieldname = EmbedFieldName(goTyp)
   201  	}
   202  	for _, f := range methodNames {
   203  		if f == fieldname {
   204  			return fieldname + "_"
   205  		}
   206  	}
   207  	if !gogoproto.IsProtoSizer(message.file.FileDescriptorProto, message.DescriptorProto) {
   208  		if fieldname == "Size" {
   209  			return fieldname + "_"
   210  		}
   211  	}
   212  	return fieldname
   213  }
   214  
   215  func (g *Generator) IsMap(field *descriptor.FieldDescriptorProto) bool {
   216  	if !field.IsMessage() {
   217  		return false
   218  	}
   219  	byName := g.ObjectNamed(field.GetTypeName())
   220  	desc, ok := byName.(*Descriptor)
   221  	if byName == nil || !ok || !desc.GetOptions().GetMapEntry() {
   222  		return false
   223  	}
   224  	return true
   225  }
   226  
   227  func (g *Generator) GetMapKeyField(field, keyField *descriptor.FieldDescriptorProto) *descriptor.FieldDescriptorProto {
   228  	if !gogoproto.IsCastKey(field) {
   229  		return keyField
   230  	}
   231  	keyField = proto.Clone(keyField).(*descriptor.FieldDescriptorProto)
   232  	if keyField.Options == nil {
   233  		keyField.Options = &descriptor.FieldOptions{}
   234  	}
   235  	keyType := gogoproto.GetCastKey(field)
   236  	if err := proto.SetExtension(keyField.Options, gogoproto.E_Casttype, &keyType); err != nil {
   237  		g.Fail(err.Error())
   238  	}
   239  	return keyField
   240  }
   241  
   242  func (g *Generator) GetMapValueField(field, valField *descriptor.FieldDescriptorProto) *descriptor.FieldDescriptorProto {
   243  	if gogoproto.IsCustomType(field) && gogoproto.IsCastValue(field) {
   244  		g.Fail("cannot have a customtype and casttype: ", field.String())
   245  	}
   246  	valField = proto.Clone(valField).(*descriptor.FieldDescriptorProto)
   247  	if valField.Options == nil {
   248  		valField.Options = &descriptor.FieldOptions{}
   249  	}
   250  
   251  	stdtime := gogoproto.IsStdTime(field)
   252  	if stdtime {
   253  		if err := proto.SetExtension(valField.Options, gogoproto.E_Stdtime, &stdtime); err != nil {
   254  			g.Fail(err.Error())
   255  		}
   256  	}
   257  
   258  	stddur := gogoproto.IsStdDuration(field)
   259  	if stddur {
   260  		if err := proto.SetExtension(valField.Options, gogoproto.E_Stdduration, &stddur); err != nil {
   261  			g.Fail(err.Error())
   262  		}
   263  	}
   264  
   265  	wktptr := gogoproto.IsWktPtr(field)
   266  	if wktptr {
   267  		if err := proto.SetExtension(valField.Options, gogoproto.E_Wktpointer, &wktptr); err != nil {
   268  			g.Fail(err.Error())
   269  		}
   270  	}
   271  
   272  	if valType := gogoproto.GetCastValue(field); len(valType) > 0 {
   273  		if err := proto.SetExtension(valField.Options, gogoproto.E_Casttype, &valType); err != nil {
   274  			g.Fail(err.Error())
   275  		}
   276  	}
   277  	if valType := gogoproto.GetCustomType(field); len(valType) > 0 {
   278  		if err := proto.SetExtension(valField.Options, gogoproto.E_Customtype, &valType); err != nil {
   279  			g.Fail(err.Error())
   280  		}
   281  	}
   282  
   283  	nullable := gogoproto.IsNullable(field)
   284  	if err := proto.SetExtension(valField.Options, gogoproto.E_Nullable, &nullable); err != nil {
   285  		g.Fail(err.Error())
   286  	}
   287  	return valField
   288  }
   289  
   290  // GoMapValueTypes returns the map value Go type and the alias map value Go type (for casting), taking into
   291  // account whether the map is nullable or the value is a message.
   292  func GoMapValueTypes(mapField, valueField *descriptor.FieldDescriptorProto, goValueType, goValueAliasType string) (nullable bool, outGoType string, outGoAliasType string) {
   293  	nullable = gogoproto.IsNullable(mapField) && (valueField.IsMessage() || gogoproto.IsCustomType(mapField))
   294  	if nullable {
   295  		// ensure the non-aliased Go value type is a pointer for consistency
   296  		if strings.HasPrefix(goValueType, "*") {
   297  			outGoType = goValueType
   298  		} else {
   299  			outGoType = "*" + goValueType
   300  		}
   301  		outGoAliasType = goValueAliasType
   302  	} else {
   303  		outGoType = strings.Replace(goValueType, "*", "", 1)
   304  		outGoAliasType = strings.Replace(goValueAliasType, "*", "", 1)
   305  	}
   306  	return
   307  }
   308  
   309  func GoTypeToName(goTyp string) string {
   310  	return strings.Replace(strings.Replace(goTyp, "*", "", -1), "[]", "", -1)
   311  }
   312  
   313  func EmbedFieldName(goTyp string) string {
   314  	goTyp = GoTypeToName(goTyp)
   315  	goTyps := strings.Split(goTyp, ".")
   316  	if len(goTyps) == 1 {
   317  		return goTyp
   318  	}
   319  	if len(goTyps) == 2 {
   320  		return goTyps[1]
   321  	}
   322  	panic("unreachable")
   323  }
   324  
   325  func (g *Generator) GeneratePlugin(p Plugin) {
   326  	plugins = []Plugin{p}
   327  	p.Init(g)
   328  	// Generate the output. The generator runs for every file, even the files
   329  	// that we don't generate output for, so that we can collate the full list
   330  	// of exported symbols to support public imports.
   331  	genFileMap := make(map[*FileDescriptor]bool, len(g.genFiles))
   332  	for _, file := range g.genFiles {
   333  		genFileMap[file] = true
   334  	}
   335  	for _, file := range g.allFiles {
   336  		g.Reset()
   337  		g.writeOutput = genFileMap[file]
   338  		g.generatePlugin(file, p)
   339  		if !g.writeOutput {
   340  			continue
   341  		}
   342  		g.Response.File = append(g.Response.File, &plugin.CodeGeneratorResponse_File{
   343  			Name:    proto.String(file.goFileName(g.pathType)),
   344  			Content: proto.String(g.String()),
   345  		})
   346  	}
   347  }
   348  
   349  func (g *Generator) SetFile(filename string) {
   350  	g.file = g.fileByName(filename)
   351  }
   352  
   353  func (g *Generator) generatePlugin(file *FileDescriptor, p Plugin) {
   354  	g.writtenImports = make(map[string]bool)
   355  	g.usedPackages = make(map[GoImportPath]bool)
   356  	g.packageNames = make(map[GoImportPath]GoPackageName)
   357  	g.usedPackageNames = make(map[GoPackageName]bool)
   358  	g.addedImports = make(map[GoImportPath]bool)
   359  	g.file = file
   360  
   361  	// Run the plugins before the imports so we know which imports are necessary.
   362  	p.Generate(file)
   363  
   364  	// Generate header and imports last, though they appear first in the output.
   365  	rem := g.Buffer
   366  	g.Buffer = new(bytes.Buffer)
   367  	g.generateHeader()
   368  	// p.GenerateImports(g.file)
   369  	g.generateImports()
   370  	if !g.writeOutput {
   371  		return
   372  	}
   373  	g.Write(rem.Bytes())
   374  
   375  	// Reformat generated code.
   376  	contents := string(g.Buffer.Bytes())
   377  	fset := token.NewFileSet()
   378  	ast, err := parser.ParseFile(fset, "", g, parser.ParseComments)
   379  	if err != nil {
   380  		g.Fail("bad Go source code was generated:", contents, err.Error())
   381  		return
   382  	}
   383  	g.Reset()
   384  	err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(g, fset, ast)
   385  	if err != nil {
   386  		g.Fail("generated Go source code could not be reformatted:", err.Error())
   387  	}
   388  }
   389  
   390  func GetCustomType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) {
   391  	return getCustomType(field)
   392  }
   393  
   394  func getCustomType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) {
   395  	if field.Options != nil {
   396  		var v interface{}
   397  		v, err = proto.GetExtension(field.Options, gogoproto.E_Customtype)
   398  		if err == nil && v.(*string) != nil {
   399  			ctype := *(v.(*string))
   400  			packageName, typ = splitCPackageType(ctype)
   401  			return packageName, typ, nil
   402  		}
   403  	}
   404  	return "", "", err
   405  }
   406  
   407  func splitCPackageType(ctype string) (packageName string, typ string) {
   408  	ss := strings.Split(ctype, ".")
   409  	if len(ss) == 1 {
   410  		return "", ctype
   411  	}
   412  	packageName = strings.Join(ss[0:len(ss)-1], ".")
   413  	typeName := ss[len(ss)-1]
   414  	importStr := strings.Map(badToUnderscore, packageName)
   415  	typ = importStr + "." + typeName
   416  	return packageName, typ
   417  }
   418  
   419  func getCastType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) {
   420  	if field.Options != nil {
   421  		var v interface{}
   422  		v, err = proto.GetExtension(field.Options, gogoproto.E_Casttype)
   423  		if err == nil && v.(*string) != nil {
   424  			ctype := *(v.(*string))
   425  			packageName, typ = splitCPackageType(ctype)
   426  			return packageName, typ, nil
   427  		}
   428  	}
   429  	return "", "", err
   430  }
   431  
   432  func FileName(file *FileDescriptor) string {
   433  	fname := path.Base(file.FileDescriptorProto.GetName())
   434  	fname = strings.Replace(fname, ".proto", "", -1)
   435  	fname = strings.Replace(fname, "-", "_", -1)
   436  	fname = strings.Replace(fname, ".", "_", -1)
   437  	return CamelCase(fname)
   438  }
   439  
   440  func (g *Generator) AllFiles() *descriptor.FileDescriptorSet {
   441  	set := &descriptor.FileDescriptorSet{}
   442  	set.File = make([]*descriptor.FileDescriptorProto, len(g.allFiles))
   443  	for i := range g.allFiles {
   444  		set.File[i] = g.allFiles[i].FileDescriptorProto
   445  	}
   446  	return set
   447  }
   448  
   449  func (d *Descriptor) Path() string {
   450  	return d.path
   451  }
   452  
   453  func (g *Generator) useTypes() string {
   454  	pkg := strings.Map(badToUnderscore, "github.com/gogo/protobuf/types")
   455  	g.customImports = append(g.customImports, "github.com/gogo/protobuf/types")
   456  	return pkg
   457  }
   458  
   459  func (d *FileDescriptor) GoPackageName() string {
   460  	return string(d.packageName)
   461  }
   462  

View as plain text