...

Source file src/github.com/gogo/protobuf/plugin/size/size.go

Documentation: github.com/gogo/protobuf/plugin/size

     1  // Protocol Buffers for Go with Gadgets
     2  //
     3  // Copyright (c) 2013, The GoGo Authors. All rights reserved.
     4  // http://github.com/gogo/protobuf
     5  //
     6  // Redistribution and use in source and binary forms, with or without
     7  // modification, are permitted provided that the following conditions are
     8  // met:
     9  //
    10  //     * Redistributions of source code must retain the above copyright
    11  // notice, this list of conditions and the following disclaimer.
    12  //     * Redistributions in binary form must reproduce the above
    13  // copyright notice, this list of conditions and the following disclaimer
    14  // in the documentation and/or other materials provided with the
    15  // distribution.
    16  //
    17  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    18  // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    19  // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    20  // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    21  // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    22  // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    23  // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    24  // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    25  // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    26  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    27  // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    28  
    29  /*
    30  The size plugin generates a Size or ProtoSize method for each message.
    31  This is useful with the MarshalTo method generated by the marshalto plugin and the
    32  gogoproto.marshaler and gogoproto.marshaler_all extensions.
    33  
    34  It is enabled by the following extensions:
    35  
    36    - sizer
    37    - sizer_all
    38    - protosizer
    39    - protosizer_all
    40  
    41  The size plugin also generates a test when it is enabled using one of the following extensions:
    42  
    43    - testgen
    44    - testgen_all
    45  
    46  And a benchmark when it is enabled using one of the following extensions:
    47  
    48    - benchgen
    49    - benchgen_all
    50  
    51  Let us look at:
    52  
    53    github.com/gogo/protobuf/test/example/example.proto
    54  
    55  All of the generated output can be seen at:
    56  
    57    github.com/gogo/protobuf/test/example/*
    58  
    59  The following message:
    60  
    61    option (gogoproto.sizer_all) = true;
    62  
    63    message B {
    64  	option (gogoproto.description) = true;
    65  	optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true];
    66  	repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false];
    67    }
    68  
    69  given to the size plugin, will generate the following code:
    70  
    71    func (m *B) Size() (n int) {
    72  	if m == nil {
    73  		return 0
    74  	}
    75  	var l int
    76  	_ = l
    77  	l = m.A.Size()
    78  	n += 1 + l + sovExample(uint64(l))
    79  	if len(m.G) > 0 {
    80  		for _, e := range m.G {
    81  			l = e.Size()
    82  			n += 1 + l + sovExample(uint64(l))
    83  		}
    84  	}
    85  	if m.XXX_unrecognized != nil {
    86  		n += len(m.XXX_unrecognized)
    87  	}
    88  	return n
    89    }
    90  
    91  and the following test code:
    92  
    93  	func TestBSize(t *testing5.T) {
    94  		popr := math_rand5.New(math_rand5.NewSource(time5.Now().UnixNano()))
    95  		p := NewPopulatedB(popr, true)
    96  		dAtA, err := github_com_gogo_protobuf_proto2.Marshal(p)
    97  		if err != nil {
    98  			panic(err)
    99  		}
   100  		size := p.Size()
   101  		if len(dAtA) != size {
   102  			t.Fatalf("size %v != marshalled size %v", size, len(dAtA))
   103  		}
   104  	}
   105  
   106  	func BenchmarkBSize(b *testing5.B) {
   107  		popr := math_rand5.New(math_rand5.NewSource(616))
   108  		total := 0
   109  		pops := make([]*B, 1000)
   110  		for i := 0; i < 1000; i++ {
   111  			pops[i] = NewPopulatedB(popr, false)
   112  		}
   113  		b.ResetTimer()
   114  		for i := 0; i < b.N; i++ {
   115  			total += pops[i%1000].Size()
   116  		}
   117  		b.SetBytes(int64(total / b.N))
   118  	}
   119  
   120  The sovExample function, generated into example.pb.go, returns the number of bytes needed to encode a value as a varint.
   121  
   122  */
   123  package size
   124  
   125  import (
   126  	"fmt"
   127  	"os"
   128  	"strconv"
   129  	"strings"
   130  
   131  	"github.com/gogo/protobuf/gogoproto"
   132  	"github.com/gogo/protobuf/proto"
   133  	descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
   134  	"github.com/gogo/protobuf/protoc-gen-gogo/generator"
   135  	"github.com/gogo/protobuf/vanity"
   136  )
   137  
   138  type size struct {
   139  	*generator.Generator
   140  	generator.PluginImports
   141  	atleastOne bool
   142  	localName  string
   143  	typesPkg   generator.Single
   144  	bitsPkg    generator.Single
   145  }
   146  
   147  func NewSize() *size {
   148  	return &size{}
   149  }
   150  
   151  func (p *size) Name() string {
   152  	return "size"
   153  }
   154  
   155  func (p *size) Init(g *generator.Generator) {
   156  	p.Generator = g
   157  }
   158  
   159  func wireToType(wire string) int {
   160  	switch wire {
   161  	case "fixed64":
   162  		return proto.WireFixed64
   163  	case "fixed32":
   164  		return proto.WireFixed32
   165  	case "varint":
   166  		return proto.WireVarint
   167  	case "bytes":
   168  		return proto.WireBytes
   169  	case "group":
   170  		return proto.WireBytes
   171  	case "zigzag32":
   172  		return proto.WireVarint
   173  	case "zigzag64":
   174  		return proto.WireVarint
   175  	}
   176  	panic("unreachable")
   177  }
   178  
   179  func keySize(fieldNumber int32, wireType int) int {
   180  	x := uint32(fieldNumber)<<3 | uint32(wireType)
   181  	size := 0
   182  	for size = 0; x > 127; size++ {
   183  		x >>= 7
   184  	}
   185  	size++
   186  	return size
   187  }
   188  
   189  func (p *size) sizeVarint() {
   190  	p.P(`
   191  	func sov`, p.localName, `(x uint64) (n int) {
   192                  return (`, p.bitsPkg.Use(), `.Len64(x | 1) + 6)/ 7
   193  	}`)
   194  }
   195  
   196  func (p *size) sizeZigZag() {
   197  	p.P(`func soz`, p.localName, `(x uint64) (n int) {
   198  		return sov`, p.localName, `(uint64((x << 1) ^ uint64((int64(x) >> 63))))
   199  	}`)
   200  }
   201  
   202  func (p *size) std(field *descriptor.FieldDescriptorProto, name string) (string, bool) {
   203  	ptr := ""
   204  	if gogoproto.IsNullable(field) {
   205  		ptr = "*"
   206  	}
   207  	if gogoproto.IsStdTime(field) {
   208  		return p.typesPkg.Use() + `.SizeOfStdTime(` + ptr + name + `)`, true
   209  	} else if gogoproto.IsStdDuration(field) {
   210  		return p.typesPkg.Use() + `.SizeOfStdDuration(` + ptr + name + `)`, true
   211  	} else if gogoproto.IsStdDouble(field) {
   212  		return p.typesPkg.Use() + `.SizeOfStdDouble(` + ptr + name + `)`, true
   213  	} else if gogoproto.IsStdFloat(field) {
   214  		return p.typesPkg.Use() + `.SizeOfStdFloat(` + ptr + name + `)`, true
   215  	} else if gogoproto.IsStdInt64(field) {
   216  		return p.typesPkg.Use() + `.SizeOfStdInt64(` + ptr + name + `)`, true
   217  	} else if gogoproto.IsStdUInt64(field) {
   218  		return p.typesPkg.Use() + `.SizeOfStdUInt64(` + ptr + name + `)`, true
   219  	} else if gogoproto.IsStdInt32(field) {
   220  		return p.typesPkg.Use() + `.SizeOfStdInt32(` + ptr + name + `)`, true
   221  	} else if gogoproto.IsStdUInt32(field) {
   222  		return p.typesPkg.Use() + `.SizeOfStdUInt32(` + ptr + name + `)`, true
   223  	} else if gogoproto.IsStdBool(field) {
   224  		return p.typesPkg.Use() + `.SizeOfStdBool(` + ptr + name + `)`, true
   225  	} else if gogoproto.IsStdString(field) {
   226  		return p.typesPkg.Use() + `.SizeOfStdString(` + ptr + name + `)`, true
   227  	} else if gogoproto.IsStdBytes(field) {
   228  		return p.typesPkg.Use() + `.SizeOfStdBytes(` + ptr + name + `)`, true
   229  	}
   230  	return "", false
   231  }
   232  
// generateField emits the Go statements that add the encoded size of one field
// of message to the accumulator n inside a generated Size/ProtoSize method.
//
// The emitted code relies on locals that the caller has already emitted:
//   - m: the receiver;
//   - n: the running size;
//   - l: a scratch length.
// It also calls the file-local sov<localName>/soz<localName> helpers.
//
// Parameters:
//   - proto3: selects proto3 semantics. Singular scalar, string and bytes
//     fields are only counted when they are non-zero or non-empty.
//   - file: not referenced in this method.
//   - sizeName: "Size" or "ProtoSize", the method called to size nested
//     messages and custom types.
//
// It panics for group fields and for any field type the switch does not handle.
func (p *size) generateField(proto3 bool, file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto, sizeName string) {
	fieldname := p.GetOneOfFieldName(message, field)
	nullable := gogoproto.IsNullable(field)
	repeated := field.IsRepeated()
	doNilCheck := gogoproto.NeedsNilCheck(proto3, field)
	// Guard: empty repeated fields and unset nilable fields contribute nothing.
	// The guard opened here is closed after the switch.
	if repeated {
		p.P(`if len(m.`, fieldname, `) > 0 {`)
		p.In()
	} else if doNilCheck {
		p.P(`if m.`, fieldname, ` != nil {`)
		p.In()
	}
	packed := field.IsPacked() || (proto3 && field.IsPacked3())
	_, wire := p.GoType(message, field)
	wireType := wireToType(wire)
	fieldNumber := field.GetNumber()
	// A packed repeated field is written as a single length-delimited record,
	// so its key uses the bytes wire type regardless of the element type.
	if packed {
		wireType = proto.WireBytes
	}
	key := keySize(fieldNumber, wireType)
	switch *field.Type {
	// Fixed 8-byte values: key + 8 per element (or one key + length prefix + 8*len when packed).
	case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
		descriptor.FieldDescriptorProto_TYPE_FIXED64,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64:
		if packed {
			p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(len(m.`, fieldname, `)*8))`, `+len(m.`, fieldname, `)*8`)
		} else if repeated {
			p.P(`n+=`, strconv.Itoa(key+8), `*len(m.`, fieldname, `)`)
		} else if proto3 {
			p.P(`if m.`, fieldname, ` != 0 {`)
			p.In()
			p.P(`n+=`, strconv.Itoa(key+8))
			p.Out()
			p.P(`}`)
		} else if nullable {
			p.P(`n+=`, strconv.Itoa(key+8))
		} else {
			p.P(`n+=`, strconv.Itoa(key+8))
		}
	// Fixed 4-byte values: same shape as above with 4 in place of 8.
	case descriptor.FieldDescriptorProto_TYPE_FLOAT,
		descriptor.FieldDescriptorProto_TYPE_FIXED32,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32:
		if packed {
			p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(len(m.`, fieldname, `)*4))`, `+len(m.`, fieldname, `)*4`)
		} else if repeated {
			p.P(`n+=`, strconv.Itoa(key+4), `*len(m.`, fieldname, `)`)
		} else if proto3 {
			p.P(`if m.`, fieldname, ` != 0 {`)
			p.In()
			p.P(`n+=`, strconv.Itoa(key+4))
			p.Out()
			p.P(`}`)
		} else if nullable {
			p.P(`n+=`, strconv.Itoa(key+4))
		} else {
			p.P(`n+=`, strconv.Itoa(key+4))
		}
	// Plain varints. Values are widened with uint64(...), so negative int32 and
	// enum values sign-extend and are sized as 10-byte varints.
	case descriptor.FieldDescriptorProto_TYPE_INT64,
		descriptor.FieldDescriptorProto_TYPE_UINT64,
		descriptor.FieldDescriptorProto_TYPE_UINT32,
		descriptor.FieldDescriptorProto_TYPE_ENUM,
		descriptor.FieldDescriptorProto_TYPE_INT32:
		if packed {
			// Sum the element sizes into l, then add key + length prefix + payload.
			p.P(`l = 0`)
			p.P(`for _, e := range m.`, fieldname, ` {`)
			p.In()
			p.P(`l+=sov`, p.localName, `(uint64(e))`)
			p.Out()
			p.P(`}`)
			p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(l))+l`)
		} else if repeated {
			p.P(`for _, e := range m.`, fieldname, ` {`)
			p.In()
			p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(e))`)
			p.Out()
			p.P(`}`)
		} else if proto3 {
			p.P(`if m.`, fieldname, ` != 0 {`)
			p.In()
			p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(m.`, fieldname, `))`)
			p.Out()
			p.P(`}`)
		} else if nullable {
			p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(*m.`, fieldname, `))`)
		} else {
			p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(m.`, fieldname, `))`)
		}
	// Bools always encode as a single byte.
	case descriptor.FieldDescriptorProto_TYPE_BOOL:
		if packed {
			p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(len(m.`, fieldname, `)))`, `+len(m.`, fieldname, `)*1`)
		} else if repeated {
			p.P(`n+=`, strconv.Itoa(key+1), `*len(m.`, fieldname, `)`)
		} else if proto3 {
			p.P(`if m.`, fieldname, ` {`)
			p.In()
			p.P(`n+=`, strconv.Itoa(key+1))
			p.Out()
			p.P(`}`)
		} else if nullable {
			p.P(`n+=`, strconv.Itoa(key+1))
		} else {
			p.P(`n+=`, strconv.Itoa(key+1))
		}
	// Strings: key + varint length prefix + byte length.
	case descriptor.FieldDescriptorProto_TYPE_STRING:
		if repeated {
			p.P(`for _, s := range m.`, fieldname, ` { `)
			p.In()
			p.P(`l = len(s)`)
			p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
			p.Out()
			p.P(`}`)
		} else if proto3 {
			p.P(`l=len(m.`, fieldname, `)`)
			p.P(`if l > 0 {`)
			p.In()
			p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
			p.Out()
			p.P(`}`)
		} else if nullable {
			p.P(`l=len(*m.`, fieldname, `)`)
			p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
		} else {
			p.P(`l=len(m.`, fieldname, `)`)
			p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
		}
	case descriptor.FieldDescriptorProto_TYPE_GROUP:
		panic(fmt.Errorf("size does not support group %v", fieldname))
	case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
		if p.IsMap(field) {
			// Each map entry is sized as a nested record with the key as
			// field 1 and the value as field 2, wrapped in this field's own
			// key and length prefix. The per-entry size is accumulated as a
			// "+"-joined list of terms in sum.
			m := p.GoMapType(nil, field)
			_, keywire := p.GoType(nil, m.KeyAliasField)
			valuegoTyp, _ := p.GoType(nil, m.ValueField)
			valuegoAliasTyp, valuewire := p.GoType(nil, m.ValueAliasField)
			_, fieldwire := p.GoType(nil, field)

			// From here on, nullable describes the map VALUE, not the field.
			nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)

			fieldKeySize := keySize(field.GetNumber(), wireToType(fieldwire))
			keyKeySize := keySize(1, wireToType(keywire))
			valueKeySize := keySize(2, wireToType(valuewire))
			p.P(`for k, v := range m.`, fieldname, ` { `)
			p.In()
			p.P(`_ = k`)
			p.P(`_ = v`)
			sum := []string{strconv.Itoa(keyKeySize)}
			// Key payload size, by key type.
			switch m.KeyField.GetType() {
			case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
				descriptor.FieldDescriptorProto_TYPE_FIXED64,
				descriptor.FieldDescriptorProto_TYPE_SFIXED64:
				sum = append(sum, `8`)
			case descriptor.FieldDescriptorProto_TYPE_FLOAT,
				descriptor.FieldDescriptorProto_TYPE_FIXED32,
				descriptor.FieldDescriptorProto_TYPE_SFIXED32:
				sum = append(sum, `4`)
			case descriptor.FieldDescriptorProto_TYPE_INT64,
				descriptor.FieldDescriptorProto_TYPE_UINT64,
				descriptor.FieldDescriptorProto_TYPE_UINT32,
				descriptor.FieldDescriptorProto_TYPE_ENUM,
				descriptor.FieldDescriptorProto_TYPE_INT32:
				sum = append(sum, `sov`+p.localName+`(uint64(k))`)
			case descriptor.FieldDescriptorProto_TYPE_BOOL:
				sum = append(sum, `1`)
			case descriptor.FieldDescriptorProto_TYPE_STRING,
				descriptor.FieldDescriptorProto_TYPE_BYTES:
				sum = append(sum, `len(k)+sov`+p.localName+`(uint64(len(k)))`)
			case descriptor.FieldDescriptorProto_TYPE_SINT32,
				descriptor.FieldDescriptorProto_TYPE_SINT64:
				sum = append(sum, `soz`+p.localName+`(uint64(k))`)
			}
			// Value key + payload size, by value type. For bytes and message
			// values that may be absent, the emitted code computes the whole
			// value record into l (0 when absent) and only "l" is added to sum.
			switch m.ValueField.GetType() {
			case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
				descriptor.FieldDescriptorProto_TYPE_FIXED64,
				descriptor.FieldDescriptorProto_TYPE_SFIXED64:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, strconv.Itoa(8))
			case descriptor.FieldDescriptorProto_TYPE_FLOAT,
				descriptor.FieldDescriptorProto_TYPE_FIXED32,
				descriptor.FieldDescriptorProto_TYPE_SFIXED32:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, strconv.Itoa(4))
			case descriptor.FieldDescriptorProto_TYPE_INT64,
				descriptor.FieldDescriptorProto_TYPE_UINT64,
				descriptor.FieldDescriptorProto_TYPE_UINT32,
				descriptor.FieldDescriptorProto_TYPE_ENUM,
				descriptor.FieldDescriptorProto_TYPE_INT32:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, `sov`+p.localName+`(uint64(v))`)
			case descriptor.FieldDescriptorProto_TYPE_BOOL:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, `1`)
			case descriptor.FieldDescriptorProto_TYPE_STRING:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, `len(v)+sov`+p.localName+`(uint64(len(v)))`)
			case descriptor.FieldDescriptorProto_TYPE_BYTES:
				if gogoproto.IsCustomType(field) {
					p.P(`l = 0`)
					if nullable {
						p.P(`if v != nil {`)
						p.In()
					}
					p.P(`l = v.`, sizeName, `()`)
					p.P(`l += `, strconv.Itoa(valueKeySize), ` + sov`+p.localName+`(uint64(l))`)
					if nullable {
						p.Out()
						p.P(`}`)
					}
					sum = append(sum, `l`)
				} else {
					p.P(`l = 0`)
					// proto3 skips empty byte values; proto2 skips only nil ones.
					if proto3 {
						p.P(`if len(v) > 0 {`)
					} else {
						p.P(`if v != nil {`)
					}
					p.In()
					p.P(`l = `, strconv.Itoa(valueKeySize), ` + len(v)+sov`+p.localName+`(uint64(len(v)))`)
					p.Out()
					p.P(`}`)
					sum = append(sum, `l`)
				}
			case descriptor.FieldDescriptorProto_TYPE_SINT32,
				descriptor.FieldDescriptorProto_TYPE_SINT64:
				sum = append(sum, strconv.Itoa(valueKeySize))
				sum = append(sum, `soz`+p.localName+`(uint64(v))`)
			case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
				stdSizeCall, stdOk := p.std(m.ValueAliasField, "v")
				if nullable {
					p.P(`l = 0`)
					p.P(`if v != nil {`)
					p.In()
					// Size via the std helper, a conversion to the underlying
					// (non-alias) type, or the value's own method.
					if stdOk {
						p.P(`l = `, stdSizeCall)
					} else if valuegoTyp != valuegoAliasTyp {
						p.P(`l = ((`, valuegoTyp, `)(v)).`, sizeName, `()`)
					} else {
						p.P(`l = v.`, sizeName, `()`)
					}
					p.P(`l += `, strconv.Itoa(valueKeySize), ` + sov`+p.localName+`(uint64(l))`)
					p.Out()
					p.P(`}`)
					sum = append(sum, `l`)
				} else {
					if stdOk {
						p.P(`l = `, stdSizeCall)
					} else if valuegoTyp != valuegoAliasTyp {
						p.P(`l = ((*`, valuegoTyp, `)(&v)).`, sizeName, `()`)
					} else {
						p.P(`l = v.`, sizeName, `()`)
					}
					sum = append(sum, strconv.Itoa(valueKeySize))
					sum = append(sum, `l+sov`+p.localName+`(uint64(l))`)
				}
			}
			p.P(`mapEntrySize := `, strings.Join(sum, "+"))
			p.P(`n+=mapEntrySize+`, fieldKeySize, `+sov`, p.localName, `(uint64(mapEntrySize))`)
			p.Out()
			p.P(`}`)
		} else if repeated {
			p.P(`for _, e := range m.`, fieldname, ` { `)
			p.In()
			stdSizeCall, stdOk := p.std(field, "e")
			if stdOk {
				p.P(`l=`, stdSizeCall)
			} else {
				p.P(`l=e.`, sizeName, `()`)
			}
			p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
			p.Out()
			p.P(`}`)
		} else {
			stdSizeCall, stdOk := p.std(field, "m."+fieldname)
			if stdOk {
				p.P(`l=`, stdSizeCall)
			} else {
				p.P(`l=m.`, fieldname, `.`, sizeName, `()`)
			}
			p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
		}
	case descriptor.FieldDescriptorProto_TYPE_BYTES:
		if !gogoproto.IsCustomType(field) {
			if repeated {
				p.P(`for _, b := range m.`, fieldname, ` { `)
				p.In()
				p.P(`l = len(b)`)
				p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
				p.Out()
				p.P(`}`)
			} else if proto3 {
				p.P(`l=len(m.`, fieldname, `)`)
				p.P(`if l > 0 {`)
				p.In()
				p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
				p.Out()
				p.P(`}`)
			} else {
				p.P(`l=len(m.`, fieldname, `)`)
				p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
			}
		} else {
			// Custom bytes types report their own encoded length via sizeName.
			if repeated {
				p.P(`for _, e := range m.`, fieldname, ` { `)
				p.In()
				p.P(`l=e.`, sizeName, `()`)
				p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
				p.Out()
				p.P(`}`)
			} else {
				p.P(`l=m.`, fieldname, `.`, sizeName, `()`)
				p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
			}
		}
	// ZigZag varints: like the plain varint case but sized with soz.
	case descriptor.FieldDescriptorProto_TYPE_SINT32,
		descriptor.FieldDescriptorProto_TYPE_SINT64:
		if packed {
			p.P(`l = 0`)
			p.P(`for _, e := range m.`, fieldname, ` {`)
			p.In()
			p.P(`l+=soz`, p.localName, `(uint64(e))`)
			p.Out()
			p.P(`}`)
			p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(l))+l`)
		} else if repeated {
			p.P(`for _, e := range m.`, fieldname, ` {`)
			p.In()
			p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(e))`)
			p.Out()
			p.P(`}`)
		} else if proto3 {
			p.P(`if m.`, fieldname, ` != 0 {`)
			p.In()
			p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(m.`, fieldname, `))`)
			p.Out()
			p.P(`}`)
		} else if nullable {
			p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(*m.`, fieldname, `))`)
		} else {
			p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(m.`, fieldname, `))`)
		}
	default:
		panic("not implemented")
	}
	// Close the empty/nil guard opened at the top.
	if repeated || doNilCheck {
		p.Out()
		p.P(`}`)
	}
}
   579  
   580  func (p *size) Generate(file *generator.FileDescriptor) {
   581  	p.PluginImports = generator.NewPluginImports(p.Generator)
   582  	p.atleastOne = false
   583  	p.localName = generator.FileName(file)
   584  	p.typesPkg = p.NewImport("github.com/gogo/protobuf/types")
   585  	protoPkg := p.NewImport("github.com/gogo/protobuf/proto")
   586  	p.bitsPkg = p.NewImport("math/bits")
   587  	if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) {
   588  		protoPkg = p.NewImport("github.com/golang/protobuf/proto")
   589  	}
   590  	for _, message := range file.Messages() {
   591  		sizeName := ""
   592  		if gogoproto.IsSizer(file.FileDescriptorProto, message.DescriptorProto) && gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) {
   593  			fmt.Fprintf(os.Stderr, "ERROR: message %v cannot support both sizer and protosizer plugins\n", generator.CamelCase(*message.Name))
   594  			os.Exit(1)
   595  		}
   596  		if gogoproto.IsSizer(file.FileDescriptorProto, message.DescriptorProto) {
   597  			sizeName = "Size"
   598  		} else if gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) {
   599  			sizeName = "ProtoSize"
   600  		} else {
   601  			continue
   602  		}
   603  		if message.DescriptorProto.GetOptions().GetMapEntry() {
   604  			continue
   605  		}
   606  		p.atleastOne = true
   607  		ccTypeName := generator.CamelCaseSlice(message.TypeName())
   608  		p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`)
   609  		p.In()
   610  		p.P(`if m == nil {`)
   611  		p.In()
   612  		p.P(`return 0`)
   613  		p.Out()
   614  		p.P(`}`)
   615  		p.P(`var l int`)
   616  		p.P(`_ = l`)
   617  		oneofs := make(map[string]struct{})
   618  		for _, field := range message.Field {
   619  			oneof := field.OneofIndex != nil
   620  			if !oneof {
   621  				proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
   622  				p.generateField(proto3, file, message, field, sizeName)
   623  			} else {
   624  				fieldname := p.GetFieldName(message, field)
   625  				if _, ok := oneofs[fieldname]; ok {
   626  					continue
   627  				} else {
   628  					oneofs[fieldname] = struct{}{}
   629  				}
   630  				p.P(`if m.`, fieldname, ` != nil {`)
   631  				p.In()
   632  				p.P(`n+=m.`, fieldname, `.`, sizeName, `()`)
   633  				p.Out()
   634  				p.P(`}`)
   635  			}
   636  		}
   637  		if message.DescriptorProto.HasExtension() {
   638  			if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) {
   639  				p.P(`n += `, protoPkg.Use(), `.SizeOfInternalExtension(m)`)
   640  			} else {
   641  				p.P(`if m.XXX_extensions != nil {`)
   642  				p.In()
   643  				p.P(`n+=len(m.XXX_extensions)`)
   644  				p.Out()
   645  				p.P(`}`)
   646  			}
   647  		}
   648  		if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) {
   649  			p.P(`if m.XXX_unrecognized != nil {`)
   650  			p.In()
   651  			p.P(`n+=len(m.XXX_unrecognized)`)
   652  			p.Out()
   653  			p.P(`}`)
   654  		}
   655  		p.P(`return n`)
   656  		p.Out()
   657  		p.P(`}`)
   658  		p.P()
   659  
   660  		//Generate Size methods for oneof fields
   661  		m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto)
   662  		for _, f := range m.Field {
   663  			oneof := f.OneofIndex != nil
   664  			if !oneof {
   665  				continue
   666  			}
   667  			ccTypeName := p.OneOfTypeName(message, f)
   668  			p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`)
   669  			p.In()
   670  			p.P(`if m == nil {`)
   671  			p.In()
   672  			p.P(`return 0`)
   673  			p.Out()
   674  			p.P(`}`)
   675  			p.P(`var l int`)
   676  			p.P(`_ = l`)
   677  			vanity.TurnOffNullableForNativeTypes(f)
   678  			p.generateField(false, file, message, f, sizeName)
   679  			p.P(`return n`)
   680  			p.Out()
   681  			p.P(`}`)
   682  		}
   683  	}
   684  
   685  	if !p.atleastOne {
   686  		return
   687  	}
   688  
   689  	p.sizeVarint()
   690  	p.sizeZigZag()
   691  
   692  }
   693  
   694  func init() {
   695  	generator.RegisterPlugin(NewSize())
   696  }
   697  

View as plain text