...

Source file src/github.com/gogo/protobuf/plugin/populate/populate.go

Documentation: github.com/gogo/protobuf/plugin/populate

     1  // Protocol Buffers for Go with Gadgets
     2  //
     3  // Copyright (c) 2013, The GoGo Authors. All rights reserved.
     4  // http://github.com/gogo/protobuf
     5  //
     6  // Redistribution and use in source and binary forms, with or without
     7  // modification, are permitted provided that the following conditions are
     8  // met:
     9  //
    10  //     * Redistributions of source code must retain the above copyright
    11  // notice, this list of conditions and the following disclaimer.
    12  //     * Redistributions in binary form must reproduce the above
    13  // copyright notice, this list of conditions and the following disclaimer
    14  // in the documentation and/or other materials provided with the
    15  // distribution.
    16  //
    17  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    18  // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    19  // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    20  // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    21  // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    22  // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    23  // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    24  // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    25  // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    26  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    27  // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    28  
    29  /*
    30  The populate plugin generates a NewPopulated function.
    31  This function returns a newly populated structure.
    32  
    33  It is enabled by the following extensions:
    34  
    35    - populate
    36    - populate_all
    37  
    38  Let us look at:
    39  
    40    github.com/gogo/protobuf/test/example/example.proto
    41  
    42  Btw all the output can be seen at:
    43  
    44    github.com/gogo/protobuf/test/example/*
    45  
    46  The following message:
    47  
    48    option (gogoproto.populate_all) = true;
    49  
    50    message B {
    51  	optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true];
    52  	repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false];
    53    }
    54  
    55  given to the populate plugin, will generate the following code:
    56  
    57    func NewPopulatedB(r randyExample, easy bool) *B {
    58  	this := &B{}
    59  	v2 := NewPopulatedA(r, easy)
    60  	this.A = *v2
    61  	if r.Intn(10) != 0 {
    62  		v3 := r.Intn(10)
    63  		this.G = make([]github_com_gogo_protobuf_test_custom.Uint128, v3)
    64  		for i := 0; i < v3; i++ {
    65  			v4 := github_com_gogo_protobuf_test_custom.NewPopulatedUint128(r)
    66  			this.G[i] = *v4
    67  		}
    68  	}
    69  	if !easy && r.Intn(10) != 0 {
    70  		this.XXX_unrecognized = randUnrecognizedExample(r, 3)
    71  	}
    72  	return this
    73    }
    74  
    75  The idea is that this is useful for testing.
    76  Most of the other plugins' generated test code uses it.
    77  You will still be able to use the generated test code of other packages
    78  if you turn off the populate plugin and write your own custom NewPopulated function.
    79  
    80  If the easy flag is not set the XXX_unrecognized and XXX_extensions fields are also populated.
    81  These have caused problems with JSON marshalling and unmarshalling tests.
    82  
    83  */
    84  package populate
    85  
    86  import (
    87  	"fmt"
    88  	"math"
    89  	"strconv"
    90  	"strings"
    91  
    92  	"github.com/gogo/protobuf/gogoproto"
    93  	"github.com/gogo/protobuf/proto"
    94  	descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
    95  	"github.com/gogo/protobuf/protoc-gen-gogo/generator"
    96  	"github.com/gogo/protobuf/vanity"
    97  )
    98  
    99  type VarGen interface {
   100  	Next() string
   101  	Current() string
   102  }
   103  
   104  type varGen struct {
   105  	index int64
   106  }
   107  
   108  func NewVarGen() VarGen {
   109  	return &varGen{0}
   110  }
   111  
   112  func (this *varGen) Next() string {
   113  	this.index++
   114  	return fmt.Sprintf("v%d", this.index)
   115  }
   116  
   117  func (this *varGen) Current() string {
   118  	return fmt.Sprintf("v%d", this.index)
   119  }
   120  
// plugin is the populate code-generator plugin. It embeds the generator it
// is initialised with and the import tracker created for each file.
type plugin struct {
	*generator.Generator
	generator.PluginImports
	// varGen supplies the temporary variable names used in generated code.
	varGen VarGen
	// atleastOne is reset at the start of Generate and set once a message
	// has had a NewPopulated function emitted.
	atleastOne bool
	// localName is derived from the file being generated and is appended to
	// the names of the generated helpers (randString<localName>, ...).
	localName string
	// typesPkg is the import handle for github.com/gogo/protobuf/types.
	typesPkg generator.Single
}
   129  
   130  func NewPlugin() *plugin {
   131  	return &plugin{}
   132  }
   133  
   134  func (p *plugin) Name() string {
   135  	return "populate"
   136  }
   137  
// Init stores the generator that this plugin writes its output through.
func (p *plugin) Init(g *generator.Generator) {
	p.Generator = g
}
   141  
   142  func value(typeName string, fieldType descriptor.FieldDescriptorProto_Type) string {
   143  	switch fieldType {
   144  	case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
   145  		return typeName + "(r.Float64())"
   146  	case descriptor.FieldDescriptorProto_TYPE_FLOAT:
   147  		return typeName + "(r.Float32())"
   148  	case descriptor.FieldDescriptorProto_TYPE_INT64,
   149  		descriptor.FieldDescriptorProto_TYPE_SFIXED64,
   150  		descriptor.FieldDescriptorProto_TYPE_SINT64:
   151  		return typeName + "(r.Int63())"
   152  	case descriptor.FieldDescriptorProto_TYPE_UINT64,
   153  		descriptor.FieldDescriptorProto_TYPE_FIXED64:
   154  		return typeName + "(uint64(r.Uint32()))"
   155  	case descriptor.FieldDescriptorProto_TYPE_INT32,
   156  		descriptor.FieldDescriptorProto_TYPE_SINT32,
   157  		descriptor.FieldDescriptorProto_TYPE_SFIXED32,
   158  		descriptor.FieldDescriptorProto_TYPE_ENUM:
   159  		return typeName + "(r.Int31())"
   160  	case descriptor.FieldDescriptorProto_TYPE_UINT32,
   161  		descriptor.FieldDescriptorProto_TYPE_FIXED32:
   162  		return typeName + "(r.Uint32())"
   163  	case descriptor.FieldDescriptorProto_TYPE_BOOL:
   164  		return typeName + `(bool(r.Intn(2) == 0))`
   165  	case descriptor.FieldDescriptorProto_TYPE_STRING,
   166  		descriptor.FieldDescriptorProto_TYPE_GROUP,
   167  		descriptor.FieldDescriptorProto_TYPE_MESSAGE,
   168  		descriptor.FieldDescriptorProto_TYPE_BYTES:
   169  	}
   170  	panic(fmt.Errorf("unexpected type %v", typeName))
   171  }
   172  
   173  func negative(fieldType descriptor.FieldDescriptorProto_Type) bool {
   174  	switch fieldType {
   175  	case descriptor.FieldDescriptorProto_TYPE_UINT64,
   176  		descriptor.FieldDescriptorProto_TYPE_FIXED64,
   177  		descriptor.FieldDescriptorProto_TYPE_UINT32,
   178  		descriptor.FieldDescriptorProto_TYPE_FIXED32,
   179  		descriptor.FieldDescriptorProto_TYPE_BOOL:
   180  		return false
   181  	}
   182  	return true
   183  }
   184  
   185  func (p *plugin) getFuncName(goTypName string, field *descriptor.FieldDescriptorProto) string {
   186  	funcName := "NewPopulated" + goTypName
   187  	goTypNames := strings.Split(goTypName, ".")
   188  	if len(goTypNames) == 2 {
   189  		funcName = goTypNames[0] + ".NewPopulated" + goTypNames[1]
   190  	} else if len(goTypNames) != 1 {
   191  		panic(fmt.Errorf("unreachable: too many dots in %v", goTypName))
   192  	}
   193  	if field != nil {
   194  		switch {
   195  		case gogoproto.IsStdTime(field):
   196  			funcName = p.typesPkg.Use() + ".NewPopulatedStdTime"
   197  		case gogoproto.IsStdDuration(field):
   198  			funcName = p.typesPkg.Use() + ".NewPopulatedStdDuration"
   199  		case gogoproto.IsStdDouble(field):
   200  			funcName = p.typesPkg.Use() + ".NewPopulatedStdDouble"
   201  		case gogoproto.IsStdFloat(field):
   202  			funcName = p.typesPkg.Use() + ".NewPopulatedStdFloat"
   203  		case gogoproto.IsStdInt64(field):
   204  			funcName = p.typesPkg.Use() + ".NewPopulatedStdInt64"
   205  		case gogoproto.IsStdUInt64(field):
   206  			funcName = p.typesPkg.Use() + ".NewPopulatedStdUInt64"
   207  		case gogoproto.IsStdInt32(field):
   208  			funcName = p.typesPkg.Use() + ".NewPopulatedStdInt32"
   209  		case gogoproto.IsStdUInt32(field):
   210  			funcName = p.typesPkg.Use() + ".NewPopulatedStdUInt32"
   211  		case gogoproto.IsStdBool(field):
   212  			funcName = p.typesPkg.Use() + ".NewPopulatedStdBool"
   213  		case gogoproto.IsStdString(field):
   214  			funcName = p.typesPkg.Use() + ".NewPopulatedStdString"
   215  		case gogoproto.IsStdBytes(field):
   216  			funcName = p.typesPkg.Use() + ".NewPopulatedStdBytes"
   217  		}
   218  	}
   219  	return funcName
   220  }
   221  
   222  func (p *plugin) getFuncCall(goTypName string, field *descriptor.FieldDescriptorProto) string {
   223  	funcName := p.getFuncName(goTypName, field)
   224  	funcCall := funcName + "(r, easy)"
   225  	return funcCall
   226  }
   227  
   228  func (p *plugin) getCustomFuncCall(goTypName string) string {
   229  	funcName := p.getFuncName(goTypName, nil)
   230  	funcCall := funcName + "(r)"
   231  	return funcCall
   232  }
   233  
   234  func (p *plugin) getEnumVal(field *descriptor.FieldDescriptorProto, goTyp string) string {
   235  	enum := p.ObjectNamed(field.GetTypeName()).(*generator.EnumDescriptor)
   236  	l := len(enum.Value)
   237  	values := make([]string, l)
   238  	for i := range enum.Value {
   239  		values[i] = strconv.Itoa(int(*enum.Value[i].Number))
   240  	}
   241  	arr := "[]int32{" + strings.Join(values, ",") + "}"
   242  	val := strings.Join([]string{generator.GoTypeToName(goTyp), `(`, arr, `[r.Intn(`, fmt.Sprintf("%d", l), `)])`}, "")
   243  	return val
   244  }
   245  
// GenerateField emits, through p.P, the Go statements that assign a random
// value to this.<fieldname> for a single field of message.
//
// The statements emitted depend on the kind of field, tried in this order:
// map, custom type, message or group, enum, bytes, string, and finally the
// remaining scalar types. The emitted code refers to the variables r, easy
// and this, which the enclosing generated function is expected to declare.
// Temporary variable names are drawn from p.varGen, so the order of Next and
// Current calls below determines the names that appear in the output.
func (p *plugin) GenerateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) {
	proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
	goTyp, _ := p.GoType(message, field)
	fieldname := p.GetOneOfFieldName(message, field)
	goTypName := generator.GoTypeToName(goTyp)
	if p.IsMap(field) {
		// Map field: emit a loop that inserts a random number (r.Intn(10))
		// of entries.
		m := p.GoMapType(nil, field)
		keygoTyp, _ := p.GoType(nil, m.KeyField)
		keygoTyp = strings.Replace(keygoTyp, "*", "", 1)
		keygoAliasTyp, _ := p.GoType(nil, m.KeyAliasField)
		keygoAliasTyp = strings.Replace(keygoAliasTyp, "*", "", 1)

		valuegoTyp, _ := p.GoType(nil, m.ValueField)
		valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField)
		keytypName := generator.GoTypeToName(keygoTyp)
		keygoAliasTyp = generator.GoTypeToName(keygoAliasTyp)
		valuetypAliasName := generator.GoTypeToName(valuegoAliasTyp)

		nullable, valuegoTyp, valuegoAliasTyp := generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)

		p.P(p.varGen.Next(), ` := r.Intn(10)`)
		p.P(`this.`, fieldname, ` = make(`, m.GoType, `)`)
		p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`)
		p.In()
		// keyval is the source text of the expression producing a random key.
		keyval := ""
		if m.KeyField.IsString() {
			keyval = fmt.Sprintf("randString%v(r)", p.localName)
		} else {
			keyval = value(keytypName, m.KeyField.GetType())
		}
		if keygoAliasTyp != keygoTyp {
			// Convert the key to its alias type when the two differ.
			keyval = keygoAliasTyp + `(` + keyval + `)`
		}
		if m.ValueField.IsMessage() || p.IsGroup(field) ||
			(m.ValueField.IsBytes() && gogoproto.IsCustomType(field)) {
			// Values built by calling a NewPopulated function.
			s := `this.` + fieldname + `[` + keyval + `] = `
			if gogoproto.IsStdType(field) {
				valuegoTyp = valuegoAliasTyp
			}
			funcCall := p.getCustomFuncCall(goTypName)
			if !gogoproto.IsCustomType(field) {
				goTypName = generator.GoTypeToName(valuegoTyp)
				funcCall = p.getFuncCall(goTypName, m.ValueAliasField)
			}
			if !nullable {
				// Non-nullable values are stored by value, so dereference
				// the result of the call.
				funcCall = `*` + funcCall
			}
			if valuegoTyp != valuegoAliasTyp {
				funcCall = `(` + valuegoAliasTyp + `)(` + funcCall + `)`
			}
			s += funcCall
			p.P(s)
		} else if m.ValueField.IsEnum() {
			s := `this.` + fieldname + `[` + keyval + `]` + ` = ` + p.getEnumVal(m.ValueField, valuegoTyp)
			p.P(s)
		} else if m.ValueField.IsBytes() {
			// Bytes value: a slice of r.Intn(100) random bytes. The key is
			// stored in a temporary so the inner loop can index the same
			// entry repeatedly.
			count := p.varGen.Next()
			p.P(count, ` := r.Intn(100)`)
			p.P(p.varGen.Next(), ` := `, keyval)
			p.P(`this.`, fieldname, `[`, p.varGen.Current(), `] = make(`, valuegoTyp, `, `, count, `)`)
			p.P(`for i := 0; i < `, count, `; i++ {`)
			p.In()
			p.P(`this.`, fieldname, `[`, p.varGen.Current(), `][i] = byte(r.Intn(256))`)
			p.Out()
			p.P(`}`)
		} else if m.ValueField.IsString() {
			s := `this.` + fieldname + `[` + keyval + `]` + ` = ` + fmt.Sprintf("randString%v(r)", p.localName)
			p.P(s)
		} else {
			// Remaining scalar value types; the key goes into a temporary so
			// the optional sign flip addresses the same entry.
			p.P(p.varGen.Next(), ` := `, keyval)
			p.P(`this.`, fieldname, `[`, p.varGen.Current(), `] = `, value(valuetypAliasName, m.ValueField.GetType()))
			if negative(m.ValueField.GetType()) {
				p.P(`if r.Intn(2) == 0 {`)
				p.In()
				p.P(`this.`, fieldname, `[`, p.varGen.Current(), `] *= -1`)
				p.Out()
				p.P(`}`)
			}
		}
		p.Out()
		p.P(`}`)
	} else if gogoproto.IsCustomType(field) {
		// Custom type: populated by calling its NewPopulated function with r.
		funcCall := p.getCustomFuncCall(goTypName)
		if field.IsRepeated() {
			p.P(p.varGen.Next(), ` := r.Intn(10)`)
			p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`)
			p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`)
			p.In()
			p.P(p.varGen.Next(), `:= `, funcCall)
			p.P(`this.`, fieldname, `[i] = *`, p.varGen.Current())
			p.Out()
			p.P(`}`)
		} else if gogoproto.IsNullable(field) {
			p.P(`this.`, fieldname, ` = `, funcCall)
		} else {
			p.P(p.varGen.Next(), `:= `, funcCall)
			p.P(`this.`, fieldname, ` = *`, p.varGen.Current())
		}
	} else if field.IsMessage() || p.IsGroup(field) {
		// Message or group: populated by calling NewPopulated<Type>(r, easy).
		// Repeated fields get r.Intn(5) elements.
		funcCall := p.getFuncCall(goTypName, field)
		if field.IsRepeated() {
			p.P(p.varGen.Next(), ` := r.Intn(5)`)
			p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`)
			p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`)
			p.In()
			if gogoproto.IsNullable(field) {
				p.P(`this.`, fieldname, `[i] = `, funcCall)
			} else {
				p.P(p.varGen.Next(), `:= `, funcCall)
				p.P(`this.`, fieldname, `[i] = *`, p.varGen.Current())
			}
			p.Out()
			p.P(`}`)
		} else {
			if gogoproto.IsNullable(field) {
				p.P(`this.`, fieldname, ` = `, funcCall)
			} else {
				p.P(p.varGen.Next(), `:= `, funcCall)
				p.P(`this.`, fieldname, ` = *`, p.varGen.Current())
			}
		}
	} else {
		if field.IsEnum() {
			// Enum: pick one of the declared enum numbers at random.
			val := p.getEnumVal(field, goTyp)
			if field.IsRepeated() {
				p.P(p.varGen.Next(), ` := r.Intn(10)`)
				p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`)
				p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`)
				p.In()
				p.P(`this.`, fieldname, `[i] = `, val)
				p.Out()
				p.P(`}`)
			} else if !gogoproto.IsNullable(field) || proto3 {
				p.P(`this.`, fieldname, ` = `, val)
			} else {
				// Otherwise the field is assigned the address of a temporary.
				p.P(p.varGen.Next(), ` := `, val)
				p.P(`this.`, fieldname, ` = &`, p.varGen.Current())
			}
		} else if field.IsBytes() {
			// Bytes: each value is r.Intn(100) random bytes.
			if field.IsRepeated() {
				p.P(p.varGen.Next(), ` := r.Intn(10)`)
				p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`)
				p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`)
				p.In()
				p.P(p.varGen.Next(), ` := r.Intn(100)`)
				p.P(`this.`, fieldname, `[i] = make([]byte,`, p.varGen.Current(), `)`)
				p.P(`for j := 0; j < `, p.varGen.Current(), `; j++ {`)
				p.In()
				p.P(`this.`, fieldname, `[i][j] = byte(r.Intn(256))`)
				p.Out()
				p.P(`}`)
				p.Out()
				p.P(`}`)
			} else {
				p.P(p.varGen.Next(), ` := r.Intn(100)`)
				p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`)
				p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`)
				p.In()
				p.P(`this.`, fieldname, `[i] = byte(r.Intn(256))`)
				p.Out()
				p.P(`}`)
			}
		} else if field.IsString() {
			// String: produced by the generated randString<localName> helper.
			typName := generator.GoTypeToName(goTyp)
			val := fmt.Sprintf("%s(randString%v(r))", typName, p.localName)
			if field.IsRepeated() {
				p.P(p.varGen.Next(), ` := r.Intn(10)`)
				p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`)
				p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`)
				p.In()
				p.P(`this.`, fieldname, `[i] = `, val)
				p.Out()
				p.P(`}`)
			} else if !gogoproto.IsNullable(field) || proto3 {
				p.P(`this.`, fieldname, ` = `, val)
			} else {
				p.P(p.varGen.Next(), `:= `, val)
				p.P(`this.`, fieldname, ` = &`, p.varGen.Current())
			}
		} else {
			// Remaining scalar types: use value(), and for types where
			// negative() is true flip the sign half of the time.
			typName := generator.GoTypeToName(goTyp)
			if field.IsRepeated() {
				p.P(p.varGen.Next(), ` := r.Intn(10)`)
				p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`)
				p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`)
				p.In()
				p.P(`this.`, fieldname, `[i] = `, value(typName, field.GetType()))
				if negative(field.GetType()) {
					p.P(`if r.Intn(2) == 0 {`)
					p.In()
					p.P(`this.`, fieldname, `[i] *= -1`)
					p.Out()
					p.P(`}`)
				}
				p.Out()
				p.P(`}`)
			} else if !gogoproto.IsNullable(field) || proto3 {
				p.P(`this.`, fieldname, ` = `, value(typName, field.GetType()))
				if negative(field.GetType()) {
					p.P(`if r.Intn(2) == 0 {`)
					p.In()
					p.P(`this.`, fieldname, ` *= -1`)
					p.Out()
					p.P(`}`)
				}
			} else {
				// The sign flip is applied to the temporary before its
				// address is stored in the field.
				p.P(p.varGen.Next(), ` := `, value(typName, field.GetType()))
				if negative(field.GetType()) {
					p.P(`if r.Intn(2) == 0 {`)
					p.In()
					p.P(p.varGen.Current(), ` *= -1`)
					p.Out()
					p.P(`}`)
				}
				p.P(`this.`, fieldname, ` = &`, p.varGen.Current())
			}
		}
	}
}
   465  
   466  func (p *plugin) hasLoop(pkg string, field *descriptor.FieldDescriptorProto, visited []*generator.Descriptor, excludes []*generator.Descriptor) *generator.Descriptor {
   467  	if field.IsMessage() || p.IsGroup(field) || p.IsMap(field) {
   468  		var fieldMessage *generator.Descriptor
   469  		if p.IsMap(field) {
   470  			m := p.GoMapType(nil, field)
   471  			if !m.ValueField.IsMessage() {
   472  				return nil
   473  			}
   474  			fieldMessage = p.ObjectNamed(m.ValueField.GetTypeName()).(*generator.Descriptor)
   475  		} else {
   476  			fieldMessage = p.ObjectNamed(field.GetTypeName()).(*generator.Descriptor)
   477  		}
   478  		fieldTypeName := generator.CamelCaseSlice(fieldMessage.TypeName())
   479  		for _, message := range visited {
   480  			messageTypeName := generator.CamelCaseSlice(message.TypeName())
   481  			if fieldTypeName == messageTypeName {
   482  				for _, e := range excludes {
   483  					if fieldTypeName == generator.CamelCaseSlice(e.TypeName()) {
   484  						return nil
   485  					}
   486  				}
   487  				return fieldMessage
   488  			}
   489  		}
   490  
   491  		for _, f := range fieldMessage.Field {
   492  			if strings.HasPrefix(f.GetTypeName(), "."+pkg) {
   493  				visited = append(visited, fieldMessage)
   494  				loopTo := p.hasLoop(pkg, f, visited, excludes)
   495  				if loopTo != nil {
   496  					return loopTo
   497  				}
   498  			}
   499  		}
   500  	}
   501  	return nil
   502  }
   503  
   504  func (p *plugin) loops(pkg string, field *descriptor.FieldDescriptorProto, message *generator.Descriptor) int {
   505  	//fmt.Fprintf(os.Stderr, "loops %v %v\n", field.GetTypeName(), generator.CamelCaseSlice(message.TypeName()))
   506  	excludes := []*generator.Descriptor{}
   507  	loops := 0
   508  	for {
   509  		visited := []*generator.Descriptor{}
   510  		loopTo := p.hasLoop(pkg, field, visited, excludes)
   511  		if loopTo == nil {
   512  			break
   513  		}
   514  		//fmt.Fprintf(os.Stderr, "loopTo %v\n", generator.CamelCaseSlice(loopTo.TypeName()))
   515  		excludes = append(excludes, loopTo)
   516  		loops++
   517  	}
   518  	return loops
   519  }
   520  
// Generate emits the populate code for file: one NewPopulated<Message>
// function per message that has the populate option enabled (map entry
// messages are skipped), one NewPopulated function per oneof field of those
// messages, and, if at least one message was handled, the shared helpers
// randy<localName>, randUTF8Rune<localName>, randString<localName>,
// randUnrecognized<localName>, randField<localName> and
// encodeVarintPopulate<localName>.
func (p *plugin) Generate(file *generator.FileDescriptor) {
	// Reset the per-file state.
	p.atleastOne = false
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.varGen = NewVarGen()
	proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
	p.typesPkg = p.NewImport("github.com/gogo/protobuf/types")
	p.localName = generator.FileName(file)
	protoPkg := p.NewImport("github.com/gogo/protobuf/proto")
	if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
		protoPkg = p.NewImport("github.com/golang/protobuf/proto")
	}

	for _, message := range file.Messages() {
		if !gogoproto.HasPopulate(file.FileDescriptorProto, message.DescriptorProto) {
			continue
		}
		if message.DescriptorProto.GetOptions().GetMapEntry() {
			continue
		}
		p.atleastOne = true
		ccTypeName := generator.CamelCaseSlice(message.TypeName())
		// loopLevels[i] is the number of loops found through field i;
		// maxLoopLevel is the largest of them.
		loopLevels := make([]int, len(message.Field))
		maxLoopLevel := 0
		for i, field := range message.Field {
			loopLevels[i] = p.loops(file.GetPackage(), field, message)
			if loopLevels[i] > maxLoopLevel {
				maxLoopLevel = loopLevels[i]
			}
		}
		// Each field gets a weight of 10^(maxLoopLevel - its loop level), so
		// fields with more loops get smaller weights; ranTotal is the sum.
		ranTotal := 0
		for i := range loopLevels {
			ranTotal += int(math.Pow10(maxLoopLevel - loopLevels[i]))
		}
		p.P(`func NewPopulated`, ccTypeName, `(r randy`, p.localName, `, easy bool) *`, ccTypeName, ` {`)
		p.In()
		p.P(`this := &`, ccTypeName, `{}`)
		if gogoproto.IsUnion(message.File().FileDescriptorProto, message.DescriptorProto) && len(message.Field) > 0 {
			// Union message: populate exactly one field, chosen by a switch
			// over r.Intn(ranTotal) in which each field owns as many case
			// values as its weight.
			p.P(`fieldNum := r.Intn(`, fmt.Sprintf("%d", ranTotal), `)`)
			p.P(`switch fieldNum {`)
			k := 0
			for i, field := range message.Field {
				is := []string{}
				ran := int(math.Pow10(maxLoopLevel - loopLevels[i]))
				for j := 0; j < ran; j++ {
					is = append(is, fmt.Sprintf("%d", j+k))
				}
				k += ran
				p.P(`case `, strings.Join(is, ","), `:`)
				p.In()
				p.GenerateField(file, message, field)
				p.Out()
			}
			p.P(`}`)
		} else {
			var maxFieldNumber int32
			// oneofs records the oneof field names already handled so each
			// oneof is emitted once.
			oneofs := make(map[string]struct{})
			for fieldIndex, field := range message.Field {
				if field.GetNumber() > maxFieldNumber {
					maxFieldNumber = field.GetNumber()
				}
				oneof := field.OneofIndex != nil
				if !oneof {
					if field.IsRequired() || (!gogoproto.IsNullable(field) && !field.IsRepeated()) || (proto3 && !field.IsMessage()) {
						// Always populated.
						p.GenerateField(file, message, field)
					} else {
						// Optionally populated: with probability 1/5 when the
						// field takes part in a loop, 4/5 otherwise.
						if loopLevels[fieldIndex] > 0 {
							p.P(`if r.Intn(5) == 0 {`)
						} else {
							p.P(`if r.Intn(5) != 0 {`)
						}
						p.In()
						p.GenerateField(file, message, field)
						p.Out()
						p.P(`}`)
					}
				} else {
					fieldname := p.GetFieldName(message, field)
					if _, ok := oneofs[fieldname]; ok {
						continue
					} else {
						oneofs[fieldname] = struct{}{}
					}
					// Collect the numbers of all fields sharing this oneof's
					// field name; the generated code picks one at random.
					fieldNumbers := []int32{}
					for _, f := range message.Field {
						fname := p.GetFieldName(message, f)
						if fname == fieldname {
							fieldNumbers = append(fieldNumbers, f.GetNumber())
						}
					}

					p.P(`oneofNumber_`, fieldname, ` := `, fmt.Sprintf("%#v", fieldNumbers), `[r.Intn(`, strconv.Itoa(len(fieldNumbers)), `)]`)
					p.P(`switch oneofNumber_`, fieldname, ` {`)
					for _, f := range message.Field {
						fname := p.GetFieldName(message, f)
						if fname != fieldname {
							continue
						}
						p.P(`case `, strconv.Itoa(int(f.GetNumber())), `:`)
						p.In()
						ccTypeName := p.OneOfTypeName(message, f)
						p.P(`this.`, fname, ` = NewPopulated`, ccTypeName, `(r, easy)`)
						p.Out()
					}
					p.P(`}`)
				}
			}
			if message.DescriptorProto.HasExtension() {
				// Messages with extension ranges: unless easy is set, add up
				// to four raw extensions with random field numbers drawn from
				// the declared ranges.
				p.P(`if !easy && r.Intn(10) != 0 {`)
				p.In()
				p.P(`l := r.Intn(5)`)
				p.P(`for i := 0; i < l; i++ {`)
				p.In()
				if len(message.DescriptorProto.GetExtensionRange()) > 1 {
					p.P(`eIndex := r.Intn(`, strconv.Itoa(len(message.DescriptorProto.GetExtensionRange())), `)`)
					p.P(`fieldNumber := 0`)
					p.P(`switch eIndex {`)
					for i, e := range message.DescriptorProto.GetExtensionRange() {
						p.P(`case `, strconv.Itoa(i), `:`)
						p.In()
						p.P(`fieldNumber = r.Intn(`, strconv.Itoa(int(e.GetEnd()-e.GetStart())), `) + `, strconv.Itoa(int(e.GetStart())))
						p.Out()
						if e.GetEnd() > maxFieldNumber {
							maxFieldNumber = e.GetEnd()
						}
					}
					p.P(`}`)
				} else {
					e := message.DescriptorProto.GetExtensionRange()[0]
					p.P(`fieldNumber := r.Intn(`, strconv.Itoa(int(e.GetEnd()-e.GetStart())), `) + `, strconv.Itoa(int(e.GetStart())))
					if e.GetEnd() > maxFieldNumber {
						maxFieldNumber = e.GetEnd()
					}
				}
				// wire ends up as one of 0, 1, 2 or 5 (3 is remapped to 5).
				p.P(`wire := r.Intn(4)`)
				p.P(`if wire == 3 { wire = 5 }`)
				p.P(`dAtA := randField`, p.localName, `(nil, r, fieldNumber, wire)`)
				p.P(protoPkg.Use(), `.SetRawExtension(this, int32(fieldNumber), dAtA)`)
				p.Out()
				p.P(`}`)
				p.Out()
				p.P(`}`)
			}

			// Unrecognized data uses field numbers starting at
			// maxFieldNumber+1; it is only generated when maxFieldNumber is
			// below 1024. The guarding if statement is emitted even when the
			// message has no XXX_unrecognized field, leaving an empty body.
			if maxFieldNumber < (1 << 10) {
				p.P(`if !easy && r.Intn(10) != 0 {`)
				p.In()
				if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) {
					p.P(`this.XXX_unrecognized = randUnrecognized`, p.localName, `(r, `, strconv.Itoa(int(maxFieldNumber+1)), `)`)
				}
				p.Out()
				p.P(`}`)
			}
		}
		p.P(`return this`)
		p.Out()
		p.P(`}`)
		p.P(``)

		//Generate NewPopulated functions for oneof fields
		// The loop works on a clone of the descriptor, so the call to
		// vanity.TurnOffNullableForNativeTypes below is applied to the
		// clone's fields rather than to message.DescriptorProto.
		m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto)
		for _, f := range m.Field {
			oneof := f.OneofIndex != nil
			if !oneof {
				continue
			}
			ccTypeName := p.OneOfTypeName(message, f)
			p.P(`func NewPopulated`, ccTypeName, `(r randy`, p.localName, `, easy bool) *`, ccTypeName, ` {`)
			p.In()
			p.P(`this := &`, ccTypeName, `{}`)
			vanity.TurnOffNullableForNativeTypes(f)
			p.GenerateField(file, message, f)
			p.P(`return this`)
			p.Out()
			p.P(`}`)
		}
	}

	// Nothing was generated for this file, so no helpers are needed.
	if !p.atleastOne {
		return
	}

	// randy<localName>: the random source interface the generated code
	// requires of its r argument.
	p.P(`type randy`, p.localName, ` interface {`)
	p.In()
	p.P(`Float32() float32`)
	p.P(`Float64() float64`)
	p.P(`Int63() int64`)
	p.P(`Int31() int32`)
	p.P(`Uint32() uint32`)
	p.P(`Intn(n int) int`)
	p.Out()
	p.P(`}`)

	// randUTF8Rune<localName>: maps r.Intn(62) onto '0'-'9' (offset 48),
	// 'A'-'Z' (offset 55) and 'a'-'z' (offset 61).
	p.P(`func randUTF8Rune`, p.localName, `(r randy`, p.localName, `) rune {`)
	p.In()
	p.P(`ru := r.Intn(62)`)
	p.P(`if ru < 10 {`)
	p.In()
	p.P(`return rune(ru+48)`)
	p.Out()
	p.P(`} else if ru < 36 {`)
	p.In()
	p.P(`return rune(ru+55)`)
	p.Out()
	p.P(`}`)
	p.P(`return rune(ru+61)`)
	p.Out()
	p.P(`}`)

	// randString<localName>: a string of r.Intn(100) random runes.
	p.P(`func randString`, p.localName, `(r randy`, p.localName, `) string {`)
	p.In()
	p.P(p.varGen.Next(), ` := r.Intn(100)`)
	p.P(`tmps := make([]rune, `, p.varGen.Current(), `)`)
	p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`)
	p.In()
	p.P(`tmps[i] = randUTF8Rune`, p.localName, `(r)`)
	p.Out()
	p.P(`}`)
	p.P(`return string(tmps)`)
	p.Out()
	p.P(`}`)

	// randUnrecognized<localName>: up to four random fields whose numbers
	// are maxFieldNumber plus r.Intn(100).
	p.P(`func randUnrecognized`, p.localName, `(r randy`, p.localName, `, maxFieldNumber int) (dAtA []byte) {`)
	p.In()
	p.P(`l := r.Intn(5)`)
	p.P(`for i := 0; i < l; i++ {`)
	p.In()
	p.P(`wire := r.Intn(4)`)
	p.P(`if wire == 3 { wire = 5 }`)
	p.P(`fieldNumber := maxFieldNumber + r.Intn(100)`)
	p.P(`dAtA = randField`, p.localName, `(dAtA, r, fieldNumber, wire)`)
	p.Out()
	p.P(`}`)
	p.P(`return dAtA`)
	p.Out()
	p.P(`}`)

	// randField<localName>: appends the key followed by a random payload
	// whose form depends on wire: a varint (0), 8 bytes (1), a
	// length-prefixed run of bytes (2), or 4 bytes (any other value).
	p.P(`func randField`, p.localName, `(dAtA []byte, r randy`, p.localName, `, fieldNumber int, wire int) []byte {`)
	p.In()
	p.P(`key := uint32(fieldNumber)<<3 | uint32(wire)`)
	p.P(`switch wire {`)
	p.P(`case 0:`)
	p.In()
	p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(key))`)
	p.P(p.varGen.Next(), ` := r.Int63()`)
	p.P(`if r.Intn(2) == 0 {`)
	p.In()
	p.P(p.varGen.Current(), ` *= -1`)
	p.Out()
	p.P(`}`)
	p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(`, p.varGen.Current(), `))`)
	p.Out()
	p.P(`case 1:`)
	p.In()
	p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(key))`)
	p.P(`dAtA = append(dAtA, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))`)
	p.Out()
	p.P(`case 2:`)
	p.In()
	p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(key))`)
	p.P(`ll := r.Intn(100)`)
	p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(ll))`)
	p.P(`for j := 0; j < ll; j++ {`)
	p.In()
	p.P(`dAtA = append(dAtA, byte(r.Intn(256)))`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`default:`)
	p.In()
	p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(key))`)
	p.P(`dAtA = append(dAtA, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))`)
	p.Out()
	p.P(`}`)
	p.P(`return dAtA`)
	p.Out()
	p.P(`}`)

	// encodeVarintPopulate<localName>: appends v seven bits at a time, low
	// bits first, setting the high bit on every byte except the last.
	p.P(`func encodeVarintPopulate`, p.localName, `(dAtA []byte, v uint64) []byte {`)
	p.In()
	p.P(`for v >= 1<<7 {`)
	p.In()
	p.P(`dAtA = append(dAtA, uint8(uint64(v)&0x7f|0x80))`)
	p.P(`v >>= 7`)
	p.Out()
	p.P(`}`)
	p.P(`dAtA = append(dAtA, uint8(v))`)
	p.P(`return dAtA`)
	p.Out()
	p.P(`}`)

}
   812  
// init registers a new populate plugin with the generator package when this
// package is loaded.
func init() {
	generator.RegisterPlugin(NewPlugin())
}
   816  

View as plain text