...

Source file src/github.com/gogo/protobuf/proto/discard.go

Documentation: github.com/gogo/protobuf/proto

     1  // Go support for Protocol Buffers - Google's data interchange format
     2  //
     3  // Copyright 2017 The Go Authors.  All rights reserved.
     4  // https://github.com/golang/protobuf
     5  //
     6  // Redistribution and use in source and binary forms, with or without
     7  // modification, are permitted provided that the following conditions are
     8  // met:
     9  //
    10  //     * Redistributions of source code must retain the above copyright
    11  // notice, this list of conditions and the following disclaimer.
    12  //     * Redistributions in binary form must reproduce the above
    13  // copyright notice, this list of conditions and the following disclaimer
    14  // in the documentation and/or other materials provided with the
    15  // distribution.
    16  //     * Neither the name of Google Inc. nor the names of its
    17  // contributors may be used to endorse or promote products derived from
    18  // this software without specific prior written permission.
    19  //
    20  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    21  // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    22  // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    23  // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    24  // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    25  // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    26  // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    27  // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    28  // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    29  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    30  // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    31  
    32  package proto
    33  
    34  import (
    35  	"fmt"
    36  	"reflect"
    37  	"strings"
    38  	"sync"
    39  	"sync/atomic"
    40  )
    41  
// generatedDiscarder is the optional interface DiscardUnknown probes for.
// When a Message implements it, DiscardUnknown delegates entirely to
// XXX_DiscardUnknown instead of using the reflection-based legacy path.
type generatedDiscarder interface {
	XXX_DiscardUnknown()
}
    45  
    46  // DiscardUnknown recursively discards all unknown fields from this message
    47  // and all embedded messages.
    48  //
    49  // When unmarshaling a message with unrecognized fields, the tags and values
    50  // of such fields are preserved in the Message. This allows a later call to
    51  // marshal to be able to produce a message that continues to have those
    52  // unrecognized fields. To avoid this, DiscardUnknown is used to
    53  // explicitly clear the unknown fields after unmarshaling.
    54  //
    55  // For proto2 messages, the unknown fields of message extensions are only
    56  // discarded from messages that have been accessed via GetExtension.
    57  func DiscardUnknown(m Message) {
    58  	if m, ok := m.(generatedDiscarder); ok {
    59  		m.XXX_DiscardUnknown()
    60  		return
    61  	}
    62  	// TODO: Dynamically populate a InternalMessageInfo for legacy messages,
    63  	// but the master branch has no implementation for InternalMessageInfo,
    64  	// so it would be more work to replicate that approach.
    65  	discardLegacy(m)
    66  }
    67  
    68  // DiscardUnknown recursively discards all unknown fields.
    69  func (a *InternalMessageInfo) DiscardUnknown(m Message) {
    70  	di := atomicLoadDiscardInfo(&a.discard)
    71  	if di == nil {
    72  		di = getDiscardInfo(reflect.TypeOf(m).Elem())
    73  		atomicStoreDiscardInfo(&a.discard, di)
    74  	}
    75  	di.discard(toPointer(&m))
    76  }
    77  
// discardInfo holds the per-type table used to discard unknown fields from
// a message struct. Instances are created by getDiscardInfo with only typ
// set; the remaining fields are filled in by computeDiscardInfo.
type discardInfo struct {
	typ reflect.Type // struct type this table describes

	initialized int32 // 0: only typ is valid, 1: everything is valid
	lock        sync.Mutex // held by computeDiscardInfo while it builds the table

	fields       []discardFieldInfo // fields that need recursive discarding
	unrecognized field              // XXX_unrecognized field, or invalidField when the struct has none
}
    87  
// discardFieldInfo pairs one struct field with the function that discards
// unknown fields from whatever that field holds. The function is called with
// the message pointer offset to this field.
type discardFieldInfo struct {
	field   field // Offset of field, guaranteed to be valid
	discard func(src pointer)
}
    92  
var (
	// discardInfoMap caches one discardInfo per message struct type.
	discardInfoMap  = map[reflect.Type]*discardInfo{}
	// discardInfoLock is held by getDiscardInfo around every access to
	// discardInfoMap.
	discardInfoLock sync.Mutex
)
    97  
    98  func getDiscardInfo(t reflect.Type) *discardInfo {
    99  	discardInfoLock.Lock()
   100  	defer discardInfoLock.Unlock()
   101  	di := discardInfoMap[t]
   102  	if di == nil {
   103  		di = &discardInfo{typ: t}
   104  		discardInfoMap[t] = di
   105  	}
   106  	return di
   107  }
   108  
   109  func (di *discardInfo) discard(src pointer) {
   110  	if src.isNil() {
   111  		return // Nothing to do.
   112  	}
   113  
   114  	if atomic.LoadInt32(&di.initialized) == 0 {
   115  		di.computeDiscardInfo()
   116  	}
   117  
   118  	for _, fi := range di.fields {
   119  		sfp := src.offset(fi.field)
   120  		fi.discard(sfp)
   121  	}
   122  
   123  	// For proto2 messages, only discard unknown fields in message extensions
   124  	// that have been accessed via GetExtension.
   125  	if em, err := extendable(src.asPointerTo(di.typ).Interface()); err == nil {
   126  		// Ignore lock since DiscardUnknown is not concurrency safe.
   127  		emm, _ := em.extensionsRead()
   128  		for _, mx := range emm {
   129  			if m, ok := mx.value.(Message); ok {
   130  				DiscardUnknown(m)
   131  			}
   132  		}
   133  	}
   134  
   135  	if di.unrecognized.IsValid() {
   136  		*src.offset(di.unrecognized).toBytes() = nil
   137  	}
   138  }
   139  
   140  func (di *discardInfo) computeDiscardInfo() {
   141  	di.lock.Lock()
   142  	defer di.lock.Unlock()
   143  	if di.initialized != 0 {
   144  		return
   145  	}
   146  	t := di.typ
   147  	n := t.NumField()
   148  
   149  	for i := 0; i < n; i++ {
   150  		f := t.Field(i)
   151  		if strings.HasPrefix(f.Name, "XXX_") {
   152  			continue
   153  		}
   154  
   155  		dfi := discardFieldInfo{field: toField(&f)}
   156  		tf := f.Type
   157  
   158  		// Unwrap tf to get its most basic type.
   159  		var isPointer, isSlice bool
   160  		if tf.Kind() == reflect.Slice && tf.Elem().Kind() != reflect.Uint8 {
   161  			isSlice = true
   162  			tf = tf.Elem()
   163  		}
   164  		if tf.Kind() == reflect.Ptr {
   165  			isPointer = true
   166  			tf = tf.Elem()
   167  		}
   168  		if isPointer && isSlice && tf.Kind() != reflect.Struct {
   169  			panic(fmt.Sprintf("%v.%s cannot be a slice of pointers to primitive types", t, f.Name))
   170  		}
   171  
   172  		switch tf.Kind() {
   173  		case reflect.Struct:
   174  			switch {
   175  			case !isPointer:
   176  				panic(fmt.Sprintf("%v.%s cannot be a direct struct value", t, f.Name))
   177  			case isSlice: // E.g., []*pb.T
   178  				discardInfo := getDiscardInfo(tf)
   179  				dfi.discard = func(src pointer) {
   180  					sps := src.getPointerSlice()
   181  					for _, sp := range sps {
   182  						if !sp.isNil() {
   183  							discardInfo.discard(sp)
   184  						}
   185  					}
   186  				}
   187  			default: // E.g., *pb.T
   188  				discardInfo := getDiscardInfo(tf)
   189  				dfi.discard = func(src pointer) {
   190  					sp := src.getPointer()
   191  					if !sp.isNil() {
   192  						discardInfo.discard(sp)
   193  					}
   194  				}
   195  			}
   196  		case reflect.Map:
   197  			switch {
   198  			case isPointer || isSlice:
   199  				panic(fmt.Sprintf("%v.%s cannot be a pointer to a map or a slice of map values", t, f.Name))
   200  			default: // E.g., map[K]V
   201  				if tf.Elem().Kind() == reflect.Ptr { // Proto struct (e.g., *T)
   202  					dfi.discard = func(src pointer) {
   203  						sm := src.asPointerTo(tf).Elem()
   204  						if sm.Len() == 0 {
   205  							return
   206  						}
   207  						for _, key := range sm.MapKeys() {
   208  							val := sm.MapIndex(key)
   209  							DiscardUnknown(val.Interface().(Message))
   210  						}
   211  					}
   212  				} else {
   213  					dfi.discard = func(pointer) {} // Noop
   214  				}
   215  			}
   216  		case reflect.Interface:
   217  			// Must be oneof field.
   218  			switch {
   219  			case isPointer || isSlice:
   220  				panic(fmt.Sprintf("%v.%s cannot be a pointer to a interface or a slice of interface values", t, f.Name))
   221  			default: // E.g., interface{}
   222  				// TODO: Make this faster?
   223  				dfi.discard = func(src pointer) {
   224  					su := src.asPointerTo(tf).Elem()
   225  					if !su.IsNil() {
   226  						sv := su.Elem().Elem().Field(0)
   227  						if sv.Kind() == reflect.Ptr && sv.IsNil() {
   228  							return
   229  						}
   230  						switch sv.Type().Kind() {
   231  						case reflect.Ptr: // Proto struct (e.g., *T)
   232  							DiscardUnknown(sv.Interface().(Message))
   233  						}
   234  					}
   235  				}
   236  			}
   237  		default:
   238  			continue
   239  		}
   240  		di.fields = append(di.fields, dfi)
   241  	}
   242  
   243  	di.unrecognized = invalidField
   244  	if f, ok := t.FieldByName("XXX_unrecognized"); ok {
   245  		if f.Type != reflect.TypeOf([]byte{}) {
   246  			panic("expected XXX_unrecognized to be of type []byte")
   247  		}
   248  		di.unrecognized = toField(&f)
   249  	}
   250  
   251  	atomic.StoreInt32(&di.initialized, 1)
   252  }
   253  
   254  func discardLegacy(m Message) {
   255  	v := reflect.ValueOf(m)
   256  	if v.Kind() != reflect.Ptr || v.IsNil() {
   257  		return
   258  	}
   259  	v = v.Elem()
   260  	if v.Kind() != reflect.Struct {
   261  		return
   262  	}
   263  	t := v.Type()
   264  
   265  	for i := 0; i < v.NumField(); i++ {
   266  		f := t.Field(i)
   267  		if strings.HasPrefix(f.Name, "XXX_") {
   268  			continue
   269  		}
   270  		vf := v.Field(i)
   271  		tf := f.Type
   272  
   273  		// Unwrap tf to get its most basic type.
   274  		var isPointer, isSlice bool
   275  		if tf.Kind() == reflect.Slice && tf.Elem().Kind() != reflect.Uint8 {
   276  			isSlice = true
   277  			tf = tf.Elem()
   278  		}
   279  		if tf.Kind() == reflect.Ptr {
   280  			isPointer = true
   281  			tf = tf.Elem()
   282  		}
   283  		if isPointer && isSlice && tf.Kind() != reflect.Struct {
   284  			panic(fmt.Sprintf("%T.%s cannot be a slice of pointers to primitive types", m, f.Name))
   285  		}
   286  
   287  		switch tf.Kind() {
   288  		case reflect.Struct:
   289  			switch {
   290  			case !isPointer:
   291  				panic(fmt.Sprintf("%T.%s cannot be a direct struct value", m, f.Name))
   292  			case isSlice: // E.g., []*pb.T
   293  				for j := 0; j < vf.Len(); j++ {
   294  					discardLegacy(vf.Index(j).Interface().(Message))
   295  				}
   296  			default: // E.g., *pb.T
   297  				discardLegacy(vf.Interface().(Message))
   298  			}
   299  		case reflect.Map:
   300  			switch {
   301  			case isPointer || isSlice:
   302  				panic(fmt.Sprintf("%T.%s cannot be a pointer to a map or a slice of map values", m, f.Name))
   303  			default: // E.g., map[K]V
   304  				tv := vf.Type().Elem()
   305  				if tv.Kind() == reflect.Ptr && tv.Implements(protoMessageType) { // Proto struct (e.g., *T)
   306  					for _, key := range vf.MapKeys() {
   307  						val := vf.MapIndex(key)
   308  						discardLegacy(val.Interface().(Message))
   309  					}
   310  				}
   311  			}
   312  		case reflect.Interface:
   313  			// Must be oneof field.
   314  			switch {
   315  			case isPointer || isSlice:
   316  				panic(fmt.Sprintf("%T.%s cannot be a pointer to a interface or a slice of interface values", m, f.Name))
   317  			default: // E.g., test_proto.isCommunique_Union interface
   318  				if !vf.IsNil() && f.Tag.Get("protobuf_oneof") != "" {
   319  					vf = vf.Elem() // E.g., *test_proto.Communique_Msg
   320  					if !vf.IsNil() {
   321  						vf = vf.Elem()   // E.g., test_proto.Communique_Msg
   322  						vf = vf.Field(0) // E.g., Proto struct (e.g., *T) or primitive value
   323  						if vf.Kind() == reflect.Ptr {
   324  							discardLegacy(vf.Interface().(Message))
   325  						}
   326  					}
   327  				}
   328  			}
   329  		}
   330  	}
   331  
   332  	if vf := v.FieldByName("XXX_unrecognized"); vf.IsValid() {
   333  		if vf.Type() != reflect.TypeOf([]byte{}) {
   334  			panic("expected XXX_unrecognized to be of type []byte")
   335  		}
   336  		vf.Set(reflect.ValueOf([]byte(nil)))
   337  	}
   338  
   339  	// For proto2 messages, only discard unknown fields in message extensions
   340  	// that have been accessed via GetExtension.
   341  	if em, err := extendable(m); err == nil {
   342  		// Ignore lock since discardLegacy is not concurrency safe.
   343  		emm, _ := em.extensionsRead()
   344  		for _, mx := range emm {
   345  			if m, ok := mx.value.(Message); ok {
   346  				discardLegacy(m)
   347  			}
   348  		}
   349  	}
   350  }
   351  

View as plain text