...
1
2
3 package gstruct
4
import (
	"errors"
	"fmt"
	"reflect"
	"runtime/debug"
	"sort"
	"strings"

	"github.com/onsi/gomega/format"
	errorsutil "github.com/onsi/gomega/gstruct/errors"
	"github.com/onsi/gomega/types"
)
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 func MatchAllFields(fields Fields) types.GomegaMatcher {
35 return &FieldsMatcher{
36 Fields: fields,
37 }
38 }
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62 func MatchFields(options Options, fields Fields) types.GomegaMatcher {
63 return &FieldsMatcher{
64 Fields: fields,
65 IgnoreExtras: options&IgnoreExtras != 0,
66 IgnoreMissing: options&IgnoreMissing != 0,
67 }
68 }
69
// FieldsMatcher is a Gomega matcher for struct values that applies a
// per-field matcher to each named field of the struct. It is constructed by
// MatchAllFields and MatchFields.
type FieldsMatcher struct {
	// Fields maps a struct field name to the matcher that field's value must
	// satisfy.
	Fields Fields

	// IgnoreExtras, when true, suppresses the "unexpected field" failure for
	// struct fields that have no entry in Fields.
	IgnoreExtras bool
	// IgnoreMissing, when true, suppresses the "missing expected field"
	// failure for entries in Fields that do not name a field of the struct.
	IgnoreMissing bool

	// failures holds the per-field errors recorded by Match; it is read back
	// by FailureMessage and exposed through Failures.
	failures []error
}
82
83
// Fields maps struct field names to the matchers their values must satisfy.
type Fields map[string]types.GomegaMatcher
85
86 func (m *FieldsMatcher) Match(actual interface{}) (success bool, err error) {
87 if reflect.TypeOf(actual).Kind() != reflect.Struct {
88 return false, fmt.Errorf("%v is type %T, expected struct", actual, actual)
89 }
90
91 m.failures = m.matchFields(actual)
92 if len(m.failures) > 0 {
93 return false, nil
94 }
95 return true, nil
96 }
97
98 func (m *FieldsMatcher) matchFields(actual interface{}) (errs []error) {
99 val := reflect.ValueOf(actual)
100 typ := val.Type()
101 fields := map[string]bool{}
102 for i := 0; i < val.NumField(); i++ {
103 fieldName := typ.Field(i).Name
104 fields[fieldName] = true
105
106 err := func() (err error) {
107
108
109 defer func() {
110 if r := recover(); r != nil {
111 err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
112 }
113 }()
114
115 matcher, expected := m.Fields[fieldName]
116 if !expected {
117 if !m.IgnoreExtras {
118 return fmt.Errorf("unexpected field %s: %+v", fieldName, actual)
119 }
120 return nil
121 }
122
123 field := val.Field(i).Interface()
124
125 match, err := matcher.Match(field)
126 if err != nil {
127 return err
128 } else if !match {
129 if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
130 return errorsutil.AggregateError(nesting.Failures())
131 }
132 return errors.New(matcher.FailureMessage(field))
133 }
134 return nil
135 }()
136 if err != nil {
137 errs = append(errs, errorsutil.Nest("."+fieldName, err))
138 }
139 }
140
141 for field := range m.Fields {
142 if !fields[field] && !m.IgnoreMissing {
143 errs = append(errs, fmt.Errorf("missing expected field %s", field))
144 }
145 }
146
147 return errs
148 }
149
150 func (m *FieldsMatcher) FailureMessage(actual interface{}) (message string) {
151 failures := make([]string, len(m.failures))
152 for i := range m.failures {
153 failures[i] = m.failures[i].Error()
154 }
155 return format.Message(reflect.TypeOf(actual).Name(),
156 fmt.Sprintf("to match fields: {\n%v\n}\n", strings.Join(failures, "\n")))
157 }
158
159 func (m *FieldsMatcher) NegatedFailureMessage(actual interface{}) (message string) {
160 return format.Message(actual, "not to match fields")
161 }
162
163 func (m *FieldsMatcher) Failures() []error {
164 return m.failures
165 }
166
View as plain text