1
2
3
4
5
6
7 package unified
8
9 import (
10 "bytes"
11 "context"
12 "encoding/hex"
13 "fmt"
14 "strings"
15
16 "go.mongodb.org/mongo-driver/bson"
17 "go.mongodb.org/mongo-driver/bson/bsontype"
18 )
19
20
21
// Context key types used to carry matching state through recursive comparisons. Each is a distinct empty struct
// type so the values stored under them cannot collide with one another.
type (
	// keyPathCtxKey is the context key under which makeMatchContext stores the key path (a string) of the value
	// currently being compared.
	keyPathCtxKey struct{}

	// extraKeysAllowedCtxKey is the context key under which makeMatchContext stores whether (a bool) extra keys
	// are allowed in the document currently being compared.
	extraKeysAllowedCtxKey struct{}

	// extraKeysAllowedRootMatchCtxKey is the context key for the extra-keys setting (a bool) that makeMatchContext
	// records only when no value is stored under this key yet. It is read when evaluating $$matchAsRoot.
	extraKeysAllowedRootMatchCtxKey struct{}
)
29
30 func makeMatchContext(ctx context.Context, keyPath string, extraKeysAllowed bool) context.Context {
31 ctx = context.WithValue(ctx, keyPathCtxKey{}, keyPath)
32 ctx = context.WithValue(ctx, extraKeysAllowedCtxKey{}, extraKeysAllowed)
33
34
35 if _, ok := ctx.Value(extraKeysAllowedRootMatchCtxKey{}).(bool); !ok {
36 ctx = context.WithValue(ctx, extraKeysAllowedRootMatchCtxKey{}, extraKeysAllowed)
37 }
38
39 return ctx
40 }
41
42
43
44
45 func verifyValuesMatch(ctx context.Context, expected, actual bson.RawValue, extraKeysAllowed bool) error {
46 return verifyValuesMatchInner(makeMatchContext(ctx, "", extraKeysAllowed), expected, actual)
47 }
48
49 func verifyValuesMatchInner(ctx context.Context, expected, actual bson.RawValue) error {
50 keyPath := ctx.Value(keyPathCtxKey{}).(string)
51 extraKeysAllowed := ctx.Value(extraKeysAllowedCtxKey{}).(bool)
52
53 if expectedDoc, ok := expected.DocumentOK(); ok {
54
55
56
57 if requiresSpecialMatching(expectedDoc) {
58 if err := evaluateSpecialComparison(ctx, expectedDoc, actual); err != nil {
59 return newMatchingError(keyPath, "error doing special matching assertion: %v", err)
60 }
61 return nil
62 }
63
64 actualDoc, ok := actual.DocumentOK()
65 if !ok {
66 return newMatchingError(keyPath, "expected value to be a document but got a %s", actual.Type)
67 }
68
69
70 expectedElems, _ := expectedDoc.Elements()
71 for _, expectedElem := range expectedElems {
72 expectedKey := expectedElem.Key()
73 expectedValue := expectedElem.Value()
74
75 fullKeyPath := expectedKey
76 if keyPath != "" {
77 fullKeyPath = keyPath + "." + expectedKey
78 }
79
80
81
82 actualValue, err := actualDoc.LookupErr(expectedKey)
83 if specialDoc, ok := expectedValue.DocumentOK(); ok && requiresSpecialMatching(specialDoc) {
84
85
86
87
88 ctx = makeMatchContext(ctx, "", false)
89 if err := evaluateSpecialComparison(ctx, specialDoc, actualValue); err != nil {
90 return newMatchingError(fullKeyPath, "error doing special matching assertion: %v", err)
91 }
92 continue
93 }
94
95
96 if err != nil {
97 return newMatchingError(fullKeyPath, "key not found in actual document")
98 }
99
100
101
102
103
104 mixedTypeEvaluated, err := evaluateMixedTypeComparison(expectedKey, expectedValue, actualValue)
105 if err != nil {
106 return newMatchingError(fullKeyPath, "error doing mixed-type matching assertion: %v", err)
107 }
108 if mixedTypeEvaluated {
109 continue
110 }
111
112
113 comparisonCtx := makeMatchContext(ctx, fullKeyPath, false)
114 if err := verifyValuesMatchInner(comparisonCtx, expectedValue, actualValue); err != nil {
115 return err
116 }
117 }
118
119
120
121
122 if !extraKeysAllowed {
123 actualElems, _ := actualDoc.Elements()
124 for _, actualElem := range actualElems {
125 if _, err := expectedDoc.LookupErr(actualElem.Key()); err != nil {
126 return newMatchingError(keyPath, "extra key %q found in actual document %s", actualElem.Key(),
127 actualDoc)
128 }
129 }
130 }
131
132 return nil
133 }
134 if expectedArr, ok := expected.ArrayOK(); ok {
135 actualArr, ok := actual.ArrayOK()
136 if !ok {
137 return newMatchingError(keyPath, "expected value to be an array but got a %s", actual.Type)
138 }
139
140 expectedValues, _ := expectedArr.Values()
141 actualValues, _ := actualArr.Values()
142
143
144 if len(expectedValues) != len(actualValues) {
145 return newMatchingError(keyPath, "expected array length %d, got %d", len(expectedValues),
146 len(actualValues))
147 }
148
149 for idx, expectedValue := range expectedValues {
150
151 fullKeyPath := fmt.Sprintf("%d", idx)
152 if keyPath != "" {
153 fullKeyPath = keyPath + "." + fullKeyPath
154 }
155
156 comparisonCtx := makeMatchContext(ctx, fullKeyPath, extraKeysAllowed)
157 err := verifyValuesMatchInner(comparisonCtx, expectedValue, actualValues[idx])
158 if err != nil {
159 return err
160 }
161 }
162
163 return nil
164 }
165
166
167
168 if expected.IsNumber() {
169 if !actual.IsNumber() {
170 return newMatchingError(keyPath, "expected value to be a number but got a %s", actual.Type)
171 }
172
173 expectedInt64 := expected.AsInt64()
174 actualInt64 := actual.AsInt64()
175 if expectedInt64 != actualInt64 {
176 return newMatchingError(keyPath, "expected numeric value %d, got %d", expectedInt64, actualInt64)
177 }
178 return nil
179 }
180
181
182 if !expected.Equal(actual) {
183 return newMatchingError(keyPath, "expected value %s, got %s", expected, actual)
184 }
185 return nil
186 }
187
188
189
190 func compareDocumentToString(expected, actual bson.RawValue) error {
191 expectedDocument, ok := expected.DocumentOK()
192 if !ok {
193 return fmt.Errorf("expected value to be a document but got a %s", expected.Type)
194 }
195
196 actualString, ok := actual.StringValueOK()
197 if !ok {
198 return fmt.Errorf("expected value to be a string but got a %s", actual.Type)
199 }
200
201 if actualString != expectedDocument.String() {
202 return fmt.Errorf("expected value %s, got %s", expectedDocument.String(), actualString)
203 }
204 return nil
205 }
206
207
208
209 func evaluateMixedTypeComparison(expectedKey string, expected, actual bson.RawValue) (bool, error) {
210 switch expectedKey {
211 case "comment":
212 if expected.Type == bsontype.EmbeddedDocument && actual.Type == bsontype.String {
213 return true, compareDocumentToString(expected, actual)
214 }
215 }
216 return false, nil
217 }
218
219 func evaluateSpecialComparison(ctx context.Context, assertionDoc bson.Raw, actual bson.RawValue) error {
220 assertionElem := assertionDoc.Index(0)
221 assertion := assertionElem.Key()
222 assertionVal := assertionElem.Value()
223 extraKeysAllowed := ctx.Value(extraKeysAllowedCtxKey{}).(bool)
224 extraKeysRootMatchAllowed := ctx.Value(extraKeysAllowedRootMatchCtxKey{}).(bool)
225
226 switch assertion {
227 case "$$exists":
228 shouldExist := assertionVal.Boolean()
229 exists := actual.Validate() == nil
230 if shouldExist != exists {
231 return fmt.Errorf("expected value to exist: %v; value actually exists: %v", shouldExist, exists)
232 }
233 case "$$type":
234 possibleTypes, err := getTypesArray(assertionVal)
235 if err != nil {
236 return fmt.Errorf("error getting possible types for a $$type assertion: %v", err)
237 }
238
239 for _, possibleType := range possibleTypes {
240 if actual.Type == possibleType {
241 return nil
242 }
243 }
244 return fmt.Errorf("expected type to be one of %v but was %s", possibleTypes, actual.Type)
245 case "$$matchesEntity":
246 expected, err := entities(ctx).BSONValue(assertionVal.StringValue())
247 if err != nil {
248 return err
249 }
250
251
252 return verifyValuesMatchInner(ctx, expected, actual)
253 case "$$matchesHexBytes":
254 expectedBytes, err := hex.DecodeString(assertionVal.StringValue())
255 if err != nil {
256 return fmt.Errorf("error converting $$matcesHexBytes value to bytes: %v", err)
257 }
258
259 _, actualBytes, ok := actual.BinaryOK()
260 if !ok {
261 return fmt.Errorf("expected binary value for a $$matchesHexBytes assertion, but got a %s", actual.Type)
262 }
263 if !bytes.Equal(expectedBytes, actualBytes) {
264 return fmt.Errorf("expected bytes %v, got %v", expectedBytes, actualBytes)
265 }
266 case "$$unsetOrMatches":
267 if actual.Validate() != nil {
268 return nil
269 }
270
271
272
273 return verifyValuesMatchInner(ctx, assertionVal, actual)
274 case "$$sessionLsid":
275 sess, err := entities(ctx).session(assertionVal.StringValue())
276 if err != nil {
277 return err
278 }
279
280 expectedID := sess.ID()
281 actualID, ok := actual.DocumentOK()
282 if !ok {
283 return fmt.Errorf("expected document value for a $$sessionLsid assertion, but got a %s", actual.Type)
284 }
285 if !bytes.Equal(expectedID, actualID) {
286 return fmt.Errorf("expected lsid %v, got %v", expectedID, actualID)
287 }
288 case "$$lte":
289 if assertionVal.Type != bsontype.Int32 && assertionVal.Type != bsontype.Int64 {
290 return fmt.Errorf("expected assertionVal to be an Int32 or Int64 but got a %s", assertionVal.Type)
291 }
292 if actual.Type != bsontype.Int32 && actual.Type != bsontype.Int64 {
293 return fmt.Errorf("expected value to be an Int32 or Int64 but got a %s", actual.Type)
294 }
295
296
297
298 expectedInt64 := assertionVal.AsInt64()
299 actualInt64 := actual.AsInt64()
300 if actualInt64 > expectedInt64 {
301 return fmt.Errorf("expected numeric value %d to be less than or equal %d", actualInt64, expectedInt64)
302 }
303 return nil
304 case "$$matchAsDocument":
305 var actualDoc bson.Raw
306 str, ok := actual.StringValueOK()
307 if !ok {
308 return fmt.Errorf("expected value to be a string but got a %s", actual.Type)
309 }
310
311 if err := bson.UnmarshalExtJSON([]byte(str), true, &actualDoc); err != nil {
312 return fmt.Errorf("error unmarshalling string as document: %v", err)
313 }
314
315 if err := verifyValuesMatch(ctx, assertionVal, documentToRawValue(actualDoc), extraKeysAllowed); err != nil {
316 return fmt.Errorf("error matching $$matchAsRoot assertion: %v", err)
317 }
318 case "$$matchAsRoot":
319
320
321 if err := verifyValuesMatch(ctx, assertionVal, actual, extraKeysRootMatchAllowed); err != nil {
322 return fmt.Errorf("error matching $$matchAsRoot assertion: %v", err)
323 }
324 default:
325 return fmt.Errorf("unrecognized special matching assertion %q", assertion)
326 }
327
328 return nil
329 }
330
331 func requiresSpecialMatching(doc bson.Raw) bool {
332 elems, _ := doc.Elements()
333 return len(elems) == 1 && strings.HasPrefix(elems[0].Key(), "$$")
334 }
335
336 func getTypesArray(val bson.RawValue) ([]bsontype.Type, error) {
337 switch val.Type {
338 case bsontype.String:
339 convertedType, err := convertStringToBSONType(val.StringValue())
340 if err != nil {
341 return nil, err
342 }
343
344 return []bsontype.Type{convertedType}, nil
345 case bsontype.Array:
346 var typeStrings []string
347 if err := val.Unmarshal(&typeStrings); err != nil {
348 return nil, fmt.Errorf("error unmarshalling to slice of strings: %v", err)
349 }
350
351 var types []bsontype.Type
352 for _, typeStr := range typeStrings {
353 convertedType, err := convertStringToBSONType(typeStr)
354 if err != nil {
355 return nil, err
356 }
357
358 types = append(types, convertedType)
359 }
360 return types, nil
361 default:
362 return nil, fmt.Errorf("invalid type to convert to bsontype.Type slice: %s", val.Type)
363 }
364 }
365
366 func convertStringToBSONType(typeStr string) (bsontype.Type, error) {
367 switch typeStr {
368 case "double":
369 return bsontype.Double, nil
370 case "string":
371 return bsontype.String, nil
372 case "object":
373 return bsontype.EmbeddedDocument, nil
374 case "array":
375 return bsontype.Array, nil
376 case "binData":
377 return bsontype.Binary, nil
378 case "undefined":
379 return bsontype.Undefined, nil
380 case "objectId":
381 return bsontype.ObjectID, nil
382 case "bool":
383 return bsontype.Boolean, nil
384 case "date":
385 return bsontype.DateTime, nil
386 case "null":
387 return bsontype.Null, nil
388 case "regex":
389 return bsontype.Regex, nil
390 case "dbPointer":
391 return bsontype.DBPointer, nil
392 case "javascript":
393 return bsontype.JavaScript, nil
394 case "symbol":
395 return bsontype.Symbol, nil
396 case "javascriptWithScope":
397 return bsontype.CodeWithScope, nil
398 case "int":
399 return bsontype.Int32, nil
400 case "timestamp":
401 return bsontype.Timestamp, nil
402 case "long":
403 return bsontype.Int64, nil
404 case "decimal":
405 return bsontype.Decimal128, nil
406 case "minKey":
407 return bsontype.MinKey, nil
408 case "maxKey":
409 return bsontype.MaxKey, nil
410 default:
411 return bsontype.Type(0), fmt.Errorf("unrecognized BSON type string %q", typeStr)
412 }
413 }
414
415
416
417
// newMatchingError builds a comparison error for the value at keyPath. The msg and args are formatted as with
// fmt.Sprintf and appended to a prefix that names the key path, or "top-level" when keyPath is empty.
func newMatchingError(keyPath, msg string, args ...interface{}) error {
	detail := fmt.Sprintf(msg, args...)

	switch keyPath {
	case "":
		// An empty key path means the mismatch is in the root value itself.
		return fmt.Errorf("comparison error at top-level: %s", detail)
	default:
		return fmt.Errorf("comparison error at key %q: %s", keyPath, detail)
	}
}
425
View as plain text