...

Source file src/github.com/google/go-cmp/cmp/cmpopts/struct_filter.go

Documentation: github.com/google/go-cmp/cmp/cmpopts

     1  // Copyright 2017, The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package cmpopts
     6  
     7  import (
     8  	"fmt"
     9  	"reflect"
    10  	"strings"
    11  
    12  	"github.com/google/go-cmp/cmp"
    13  )
    14  
    15  // filterField returns a new Option where opt is only evaluated on paths that
    16  // include a specific exported field on a single struct type.
    17  // The struct type is specified by passing in a value of that type.
    18  //
    19  // The name may be a dot-delimited string (e.g., "Foo.Bar") to select a
    20  // specific sub-field that is embedded or nested within the parent struct.
    21  func filterField(typ interface{}, name string, opt cmp.Option) cmp.Option {
    22  	// TODO: This is currently unexported over concerns of how helper filters
    23  	// can be composed together easily.
    24  	// TODO: Add tests for FilterField.
    25  
    26  	sf := newStructFilter(typ, name)
    27  	return cmp.FilterPath(sf.filter, opt)
    28  }
    29  
    30  type structFilter struct {
    31  	t  reflect.Type // The root struct type to match on
    32  	ft fieldTree    // Tree of fields to match on
    33  }
    34  
    35  func newStructFilter(typ interface{}, names ...string) structFilter {
    36  	// TODO: Perhaps allow * as a special identifier to allow ignoring any
    37  	// number of path steps until the next field match?
    38  	// This could be useful when a concrete struct gets transformed into
    39  	// an anonymous struct where it is not possible to specify that by type,
    40  	// but the transformer happens to provide guarantees about the names of
    41  	// the transformed fields.
    42  
    43  	t := reflect.TypeOf(typ)
    44  	if t == nil || t.Kind() != reflect.Struct {
    45  		panic(fmt.Sprintf("%T must be a non-pointer struct", typ))
    46  	}
    47  	var ft fieldTree
    48  	for _, name := range names {
    49  		cname, err := canonicalName(t, name)
    50  		if err != nil {
    51  			panic(fmt.Sprintf("%s: %v", strings.Join(cname, "."), err))
    52  		}
    53  		ft.insert(cname)
    54  	}
    55  	return structFilter{t, ft}
    56  }
    57  
    58  func (sf structFilter) filter(p cmp.Path) bool {
    59  	for i, ps := range p {
    60  		if ps.Type().AssignableTo(sf.t) && sf.ft.matchPrefix(p[i+1:]) {
    61  			return true
    62  		}
    63  	}
    64  	return false
    65  }
    66  
    67  // fieldTree represents a set of dot-separated identifiers.
    68  //
    69  // For example, inserting the following selectors:
    70  //
    71  //	Foo
    72  //	Foo.Bar.Baz
    73  //	Foo.Buzz
    74  //	Nuka.Cola.Quantum
    75  //
    76  // Results in a tree of the form:
    77  //
    78  //	{sub: {
    79  //		"Foo": {ok: true, sub: {
    80  //			"Bar": {sub: {
    81  //				"Baz": {ok: true},
    82  //			}},
    83  //			"Buzz": {ok: true},
    84  //		}},
    85  //		"Nuka": {sub: {
    86  //			"Cola": {sub: {
    87  //				"Quantum": {ok: true},
    88  //			}},
    89  //		}},
    90  //	}}
    91  type fieldTree struct {
    92  	ok  bool                 // Whether this is a specified node
    93  	sub map[string]fieldTree // The sub-tree of fields under this node
    94  }
    95  
    96  // insert inserts a sequence of field accesses into the tree.
    97  func (ft *fieldTree) insert(cname []string) {
    98  	if ft.sub == nil {
    99  		ft.sub = make(map[string]fieldTree)
   100  	}
   101  	if len(cname) == 0 {
   102  		ft.ok = true
   103  		return
   104  	}
   105  	sub := ft.sub[cname[0]]
   106  	sub.insert(cname[1:])
   107  	ft.sub[cname[0]] = sub
   108  }
   109  
   110  // matchPrefix reports whether any selector in the fieldTree matches
   111  // the start of path p.
   112  func (ft fieldTree) matchPrefix(p cmp.Path) bool {
   113  	for _, ps := range p {
   114  		switch ps := ps.(type) {
   115  		case cmp.StructField:
   116  			ft = ft.sub[ps.Name()]
   117  			if ft.ok {
   118  				return true
   119  			}
   120  			if len(ft.sub) == 0 {
   121  				return false
   122  			}
   123  		case cmp.Indirect:
   124  		default:
   125  			return false
   126  		}
   127  	}
   128  	return false
   129  }
   130  
   131  // canonicalName returns a list of identifiers where any struct field access
   132  // through an embedded field is expanded to include the names of the embedded
   133  // types themselves.
   134  //
   135  // For example, suppose field "Foo" is not directly in the parent struct,
   136  // but actually from an embedded struct of type "Bar". Then, the canonical name
   137  // of "Foo" is actually "Bar.Foo".
   138  //
   139  // Suppose field "Foo" is not directly in the parent struct, but actually
   140  // a field in two different embedded structs of types "Bar" and "Baz".
   141  // Then the selector "Foo" causes a panic since it is ambiguous which one it
   142  // refers to. The user must specify either "Bar.Foo" or "Baz.Foo".
   143  func canonicalName(t reflect.Type, sel string) ([]string, error) {
   144  	var name string
   145  	sel = strings.TrimPrefix(sel, ".")
   146  	if sel == "" {
   147  		return nil, fmt.Errorf("name must not be empty")
   148  	}
   149  	if i := strings.IndexByte(sel, '.'); i < 0 {
   150  		name, sel = sel, ""
   151  	} else {
   152  		name, sel = sel[:i], sel[i:]
   153  	}
   154  
   155  	// Type must be a struct or pointer to struct.
   156  	if t.Kind() == reflect.Ptr {
   157  		t = t.Elem()
   158  	}
   159  	if t.Kind() != reflect.Struct {
   160  		return nil, fmt.Errorf("%v must be a struct", t)
   161  	}
   162  
   163  	// Find the canonical name for this current field name.
   164  	// If the field exists in an embedded struct, then it will be expanded.
   165  	sf, _ := t.FieldByName(name)
   166  	if !isExported(name) {
   167  		// Avoid using reflect.Type.FieldByName for unexported fields due to
   168  		// buggy behavior with regard to embeddeding and unexported fields.
   169  		// See https://golang.org/issue/4876 for details.
   170  		sf = reflect.StructField{}
   171  		for i := 0; i < t.NumField() && sf.Name == ""; i++ {
   172  			if t.Field(i).Name == name {
   173  				sf = t.Field(i)
   174  			}
   175  		}
   176  	}
   177  	if sf.Name == "" {
   178  		return []string{name}, fmt.Errorf("does not exist")
   179  	}
   180  	var ss []string
   181  	for i := range sf.Index {
   182  		ss = append(ss, t.FieldByIndex(sf.Index[:i+1]).Name)
   183  	}
   184  	if sel == "" {
   185  		return ss, nil
   186  	}
   187  	ssPost, err := canonicalName(sf.Type, sel)
   188  	return append(ss, ssPost...), err
   189  }
   190  

View as plain text