...
1
2
3 package gstruct
4
5 import (
6 "errors"
7 "fmt"
8 "reflect"
9 "runtime/debug"
10 "strings"
11
12 "github.com/onsi/gomega/format"
13 errorsutil "github.com/onsi/gomega/gstruct/errors"
14 "github.com/onsi/gomega/types"
15 )
16
17 func MatchAllKeys(keys Keys) types.GomegaMatcher {
18 return &KeysMatcher{
19 Keys: keys,
20 }
21 }
22
23 func MatchKeys(options Options, keys Keys) types.GomegaMatcher {
24 return &KeysMatcher{
25 Keys: keys,
26 IgnoreExtras: options&IgnoreExtras != 0,
27 IgnoreMissing: options&IgnoreMissing != 0,
28 }
29 }
30
31 type KeysMatcher struct {
32
33 Keys Keys
34
35
36 IgnoreExtras bool
37
38 IgnoreMissing bool
39
40
41 failures []error
42 }
43
44 type Keys map[interface{}]types.GomegaMatcher
45
46 func (m *KeysMatcher) Match(actual interface{}) (success bool, err error) {
47 if reflect.TypeOf(actual).Kind() != reflect.Map {
48 return false, fmt.Errorf("%v is type %T, expected map", actual, actual)
49 }
50
51 m.failures = m.matchKeys(actual)
52 if len(m.failures) > 0 {
53 return false, nil
54 }
55 return true, nil
56 }
57
58 func (m *KeysMatcher) matchKeys(actual interface{}) (errs []error) {
59 actualValue := reflect.ValueOf(actual)
60 keys := map[interface{}]bool{}
61 for _, keyValue := range actualValue.MapKeys() {
62 key := keyValue.Interface()
63 keys[key] = true
64
65 err := func() (err error) {
66
67
68 defer func() {
69 if r := recover(); r != nil {
70 err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
71 }
72 }()
73
74 matcher, ok := m.Keys[key]
75 if !ok {
76 if !m.IgnoreExtras {
77 return fmt.Errorf("unexpected key %s: %+v", key, actual)
78 }
79 return nil
80 }
81
82 valInterface := actualValue.MapIndex(keyValue).Interface()
83
84 match, err := matcher.Match(valInterface)
85 if err != nil {
86 return err
87 }
88
89 if !match {
90 if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
91 return errorsutil.AggregateError(nesting.Failures())
92 }
93 return errors.New(matcher.FailureMessage(valInterface))
94 }
95 return nil
96 }()
97 if err != nil {
98 errs = append(errs, errorsutil.Nest(fmt.Sprintf(".%#v", key), err))
99 }
100 }
101
102 for key := range m.Keys {
103 if !keys[key] && !m.IgnoreMissing {
104 errs = append(errs, fmt.Errorf("missing expected key %s", key))
105 }
106 }
107
108 return errs
109 }
110
111 func (m *KeysMatcher) FailureMessage(actual interface{}) (message string) {
112 failures := make([]string, len(m.failures))
113 for i := range m.failures {
114 failures[i] = m.failures[i].Error()
115 }
116 return format.Message(reflect.TypeOf(actual).Name(),
117 fmt.Sprintf("to match keys: {\n%v\n}\n", strings.Join(failures, "\n")))
118 }
119
120 func (m *KeysMatcher) NegatedFailureMessage(actual interface{}) (message string) {
121 return format.Message(actual, "not to match keys")
122 }
123
124 func (m *KeysMatcher) Failures() []error {
125 return m.failures
126 }
127
View as plain text