...

Source file src/github.com/gogo/protobuf/proto/extensions_gogo.go

Documentation: github.com/gogo/protobuf/proto

     1  // Protocol Buffers for Go with Gadgets
     2  //
     3  // Copyright (c) 2013, The GoGo Authors. All rights reserved.
     4  // http://github.com/gogo/protobuf
     5  //
     6  // Redistribution and use in source and binary forms, with or without
     7  // modification, are permitted provided that the following conditions are
     8  // met:
     9  //
    10  //     * Redistributions of source code must retain the above copyright
    11  // notice, this list of conditions and the following disclaimer.
    12  //     * Redistributions in binary form must reproduce the above
    13  // copyright notice, this list of conditions and the following disclaimer
    14  // in the documentation and/or other materials provided with the
    15  // distribution.
    16  //
    17  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    18  // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    19  // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    20  // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    21  // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    22  // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    23  // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    24  // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    25  // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    26  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    27  // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    28  
    29  package proto
    30  
    31  import (
    32  	"bytes"
    33  	"errors"
    34  	"fmt"
    35  	"io"
    36  	"reflect"
    37  	"sort"
    38  	"strings"
    39  	"sync"
    40  )
    41  
    42  type extensionsBytes interface {
    43  	Message
    44  	ExtensionRangeArray() []ExtensionRange
    45  	GetExtensions() *[]byte
    46  }
    47  
    48  type slowExtensionAdapter struct {
    49  	extensionsBytes
    50  }
    51  
    52  func (s slowExtensionAdapter) extensionsWrite() map[int32]Extension {
    53  	panic("Please report a bug to github.com/gogo/protobuf if you see this message: Writing extensions is not supported for extensions stored in a byte slice field.")
    54  }
    55  
    56  func (s slowExtensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
    57  	b := s.GetExtensions()
    58  	m, err := BytesToExtensionsMap(*b)
    59  	if err != nil {
    60  		panic(err)
    61  	}
    62  	return m, notLocker{}
    63  }
    64  
    65  func GetBoolExtension(pb Message, extension *ExtensionDesc, ifnotset bool) bool {
    66  	if reflect.ValueOf(pb).IsNil() {
    67  		return ifnotset
    68  	}
    69  	value, err := GetExtension(pb, extension)
    70  	if err != nil {
    71  		return ifnotset
    72  	}
    73  	if value == nil {
    74  		return ifnotset
    75  	}
    76  	if value.(*bool) == nil {
    77  		return ifnotset
    78  	}
    79  	return *(value.(*bool))
    80  }
    81  
    82  func (this *Extension) Equal(that *Extension) bool {
    83  	if err := this.Encode(); err != nil {
    84  		return false
    85  	}
    86  	if err := that.Encode(); err != nil {
    87  		return false
    88  	}
    89  	return bytes.Equal(this.enc, that.enc)
    90  }
    91  
    92  func (this *Extension) Compare(that *Extension) int {
    93  	if err := this.Encode(); err != nil {
    94  		return 1
    95  	}
    96  	if err := that.Encode(); err != nil {
    97  		return -1
    98  	}
    99  	return bytes.Compare(this.enc, that.enc)
   100  }
   101  
   102  func SizeOfInternalExtension(m extendableProto) (n int) {
   103  	info := getMarshalInfo(reflect.TypeOf(m))
   104  	return info.sizeV1Extensions(m.extensionsWrite())
   105  }
   106  
   107  type sortableMapElem struct {
   108  	field int32
   109  	ext   Extension
   110  }
   111  
   112  func newSortableExtensionsFromMap(m map[int32]Extension) sortableExtensions {
   113  	s := make(sortableExtensions, 0, len(m))
   114  	for k, v := range m {
   115  		s = append(s, &sortableMapElem{field: k, ext: v})
   116  	}
   117  	return s
   118  }
   119  
   120  type sortableExtensions []*sortableMapElem
   121  
   122  func (this sortableExtensions) Len() int { return len(this) }
   123  
   124  func (this sortableExtensions) Swap(i, j int) { this[i], this[j] = this[j], this[i] }
   125  
   126  func (this sortableExtensions) Less(i, j int) bool { return this[i].field < this[j].field }
   127  
   128  func (this sortableExtensions) String() string {
   129  	sort.Sort(this)
   130  	ss := make([]string, len(this))
   131  	for i := range this {
   132  		ss[i] = fmt.Sprintf("%d: %v", this[i].field, this[i].ext)
   133  	}
   134  	return "map[" + strings.Join(ss, ",") + "]"
   135  }
   136  
   137  func StringFromInternalExtension(m extendableProto) string {
   138  	return StringFromExtensionsMap(m.extensionsWrite())
   139  }
   140  
   141  func StringFromExtensionsMap(m map[int32]Extension) string {
   142  	return newSortableExtensionsFromMap(m).String()
   143  }
   144  
   145  func StringFromExtensionsBytes(ext []byte) string {
   146  	m, err := BytesToExtensionsMap(ext)
   147  	if err != nil {
   148  		panic(err)
   149  	}
   150  	return StringFromExtensionsMap(m)
   151  }
   152  
   153  func EncodeInternalExtension(m extendableProto, data []byte) (n int, err error) {
   154  	return EncodeExtensionMap(m.extensionsWrite(), data)
   155  }
   156  
   157  func EncodeInternalExtensionBackwards(m extendableProto, data []byte) (n int, err error) {
   158  	return EncodeExtensionMapBackwards(m.extensionsWrite(), data)
   159  }
   160  
   161  func EncodeExtensionMap(m map[int32]Extension, data []byte) (n int, err error) {
   162  	o := 0
   163  	for _, e := range m {
   164  		if err := e.Encode(); err != nil {
   165  			return 0, err
   166  		}
   167  		n := copy(data[o:], e.enc)
   168  		if n != len(e.enc) {
   169  			return 0, io.ErrShortBuffer
   170  		}
   171  		o += n
   172  	}
   173  	return o, nil
   174  }
   175  
   176  func EncodeExtensionMapBackwards(m map[int32]Extension, data []byte) (n int, err error) {
   177  	o := 0
   178  	end := len(data)
   179  	for _, e := range m {
   180  		if err := e.Encode(); err != nil {
   181  			return 0, err
   182  		}
   183  		n := copy(data[end-len(e.enc):], e.enc)
   184  		if n != len(e.enc) {
   185  			return 0, io.ErrShortBuffer
   186  		}
   187  		end -= n
   188  		o += n
   189  	}
   190  	return o, nil
   191  }
   192  
   193  func GetRawExtension(m map[int32]Extension, id int32) ([]byte, error) {
   194  	e := m[id]
   195  	if err := e.Encode(); err != nil {
   196  		return nil, err
   197  	}
   198  	return e.enc, nil
   199  }
   200  
   201  func size(buf []byte, wire int) (int, error) {
   202  	switch wire {
   203  	case WireVarint:
   204  		_, n := DecodeVarint(buf)
   205  		return n, nil
   206  	case WireFixed64:
   207  		return 8, nil
   208  	case WireBytes:
   209  		v, n := DecodeVarint(buf)
   210  		return int(v) + n, nil
   211  	case WireFixed32:
   212  		return 4, nil
   213  	case WireStartGroup:
   214  		offset := 0
   215  		for {
   216  			u, n := DecodeVarint(buf[offset:])
   217  			fwire := int(u & 0x7)
   218  			offset += n
   219  			if fwire == WireEndGroup {
   220  				return offset, nil
   221  			}
   222  			s, err := size(buf[offset:], wire)
   223  			if err != nil {
   224  				return 0, err
   225  			}
   226  			offset += s
   227  		}
   228  	}
   229  	return 0, fmt.Errorf("proto: can't get size for unknown wire type %d", wire)
   230  }
   231  
   232  func BytesToExtensionsMap(buf []byte) (map[int32]Extension, error) {
   233  	m := make(map[int32]Extension)
   234  	i := 0
   235  	for i < len(buf) {
   236  		tag, n := DecodeVarint(buf[i:])
   237  		if n <= 0 {
   238  			return nil, fmt.Errorf("unable to decode varint")
   239  		}
   240  		fieldNum := int32(tag >> 3)
   241  		wireType := int(tag & 0x7)
   242  		l, err := size(buf[i+n:], wireType)
   243  		if err != nil {
   244  			return nil, err
   245  		}
   246  		end := i + int(l) + n
   247  		m[int32(fieldNum)] = Extension{enc: buf[i:end]}
   248  		i = end
   249  	}
   250  	return m, nil
   251  }
   252  
   253  func NewExtension(e []byte) Extension {
   254  	ee := Extension{enc: make([]byte, len(e))}
   255  	copy(ee.enc, e)
   256  	return ee
   257  }
   258  
   259  func AppendExtension(e Message, tag int32, buf []byte) {
   260  	if ee, eok := e.(extensionsBytes); eok {
   261  		ext := ee.GetExtensions()
   262  		*ext = append(*ext, buf...)
   263  		return
   264  	}
   265  	if ee, eok := e.(extendableProto); eok {
   266  		m := ee.extensionsWrite()
   267  		ext := m[int32(tag)] // may be missing
   268  		ext.enc = append(ext.enc, buf...)
   269  		m[int32(tag)] = ext
   270  	}
   271  }
   272  
   273  func encodeExtension(extension *ExtensionDesc, value interface{}) ([]byte, error) {
   274  	u := getMarshalInfo(reflect.TypeOf(extension.ExtendedType))
   275  	ei := u.getExtElemInfo(extension)
   276  	v := value
   277  	p := toAddrPointer(&v, ei.isptr)
   278  	siz := ei.sizer(p, SizeVarint(ei.wiretag))
   279  	buf := make([]byte, 0, siz)
   280  	return ei.marshaler(buf, p, ei.wiretag, false)
   281  }
   282  
   283  func decodeExtensionFromBytes(extension *ExtensionDesc, buf []byte) (interface{}, error) {
   284  	o := 0
   285  	for o < len(buf) {
   286  		tag, n := DecodeVarint((buf)[o:])
   287  		fieldNum := int32(tag >> 3)
   288  		wireType := int(tag & 0x7)
   289  		if o+n > len(buf) {
   290  			return nil, fmt.Errorf("unable to decode extension")
   291  		}
   292  		l, err := size((buf)[o+n:], wireType)
   293  		if err != nil {
   294  			return nil, err
   295  		}
   296  		if int32(fieldNum) == extension.Field {
   297  			if o+n+l > len(buf) {
   298  				return nil, fmt.Errorf("unable to decode extension")
   299  			}
   300  			v, err := decodeExtension((buf)[o:o+n+l], extension)
   301  			if err != nil {
   302  				return nil, err
   303  			}
   304  			return v, nil
   305  		}
   306  		o += n + l
   307  	}
   308  	return defaultExtensionValue(extension)
   309  }
   310  
   311  func (this *Extension) Encode() error {
   312  	if this.enc == nil {
   313  		var err error
   314  		this.enc, err = encodeExtension(this.desc, this.value)
   315  		if err != nil {
   316  			return err
   317  		}
   318  	}
   319  	return nil
   320  }
   321  
   322  func (this Extension) GoString() string {
   323  	if err := this.Encode(); err != nil {
   324  		return fmt.Sprintf("error encoding extension: %v", err)
   325  	}
   326  	return fmt.Sprintf("proto.NewExtension(%#v)", this.enc)
   327  }
   328  
   329  func SetUnsafeExtension(pb Message, fieldNum int32, value interface{}) error {
   330  	typ := reflect.TypeOf(pb).Elem()
   331  	ext, ok := extensionMaps[typ]
   332  	if !ok {
   333  		return fmt.Errorf("proto: bad extended type; %s is not extendable", typ.String())
   334  	}
   335  	desc, ok := ext[fieldNum]
   336  	if !ok {
   337  		return errors.New("proto: bad extension number; not in declared ranges")
   338  	}
   339  	return SetExtension(pb, desc, value)
   340  }
   341  
   342  func GetUnsafeExtension(pb Message, fieldNum int32) (interface{}, error) {
   343  	typ := reflect.TypeOf(pb).Elem()
   344  	ext, ok := extensionMaps[typ]
   345  	if !ok {
   346  		return nil, fmt.Errorf("proto: bad extended type; %s is not extendable", typ.String())
   347  	}
   348  	desc, ok := ext[fieldNum]
   349  	if !ok {
   350  		return nil, fmt.Errorf("unregistered field number %d", fieldNum)
   351  	}
   352  	return GetExtension(pb, desc)
   353  }
   354  
   355  func NewUnsafeXXX_InternalExtensions(m map[int32]Extension) XXX_InternalExtensions {
   356  	x := &XXX_InternalExtensions{
   357  		p: new(struct {
   358  			mu           sync.Mutex
   359  			extensionMap map[int32]Extension
   360  		}),
   361  	}
   362  	x.p.extensionMap = m
   363  	return *x
   364  }
   365  
   366  func GetUnsafeExtensionsMap(extendable Message) map[int32]Extension {
   367  	pb := extendable.(extendableProto)
   368  	return pb.extensionsWrite()
   369  }
   370  
   371  func deleteExtension(pb extensionsBytes, theFieldNum int32, offset int) int {
   372  	ext := pb.GetExtensions()
   373  	for offset < len(*ext) {
   374  		tag, n1 := DecodeVarint((*ext)[offset:])
   375  		fieldNum := int32(tag >> 3)
   376  		wireType := int(tag & 0x7)
   377  		n2, err := size((*ext)[offset+n1:], wireType)
   378  		if err != nil {
   379  			panic(err)
   380  		}
   381  		newOffset := offset + n1 + n2
   382  		if fieldNum == theFieldNum {
   383  			*ext = append((*ext)[:offset], (*ext)[newOffset:]...)
   384  			return offset
   385  		}
   386  		offset = newOffset
   387  	}
   388  	return -1
   389  }
   390  

View as plain text