...

Source file src/go.mongodb.org/mongo-driver/mongo/integration/unified/matches.go

Documentation: go.mongodb.org/mongo-driver/mongo/integration/unified

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package unified
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"encoding/hex"
    13  	"fmt"
    14  	"strings"
    15  
    16  	"go.mongodb.org/mongo-driver/bson"
    17  	"go.mongodb.org/mongo-driver/bson/bsontype"
    18  )
    19  
    20  // keyPathCtxKey is used as a key for a Context object. The value conveys the BSON key path that is currently being
    21  // compared.
    22  type keyPathCtxKey struct{}
    23  
    24  // extraKeysAllowedCtxKey is used as a key for a Context object. The value conveys whether or not the document under
    25  // test can contain extra keys. For example, if the expected document is {x: 1}, the document {x: 1, y: 1} would match
    26  // if the value for this key is true.
    27  type extraKeysAllowedCtxKey struct{}
    28  type extraKeysAllowedRootMatchCtxKey struct{}
    29  
    30  func makeMatchContext(ctx context.Context, keyPath string, extraKeysAllowed bool) context.Context {
    31  	ctx = context.WithValue(ctx, keyPathCtxKey{}, keyPath)
    32  	ctx = context.WithValue(ctx, extraKeysAllowedCtxKey{}, extraKeysAllowed)
    33  
    34  	// The Root Match Context should be persisted once set.
    35  	if _, ok := ctx.Value(extraKeysAllowedRootMatchCtxKey{}).(bool); !ok {
    36  		ctx = context.WithValue(ctx, extraKeysAllowedRootMatchCtxKey{}, extraKeysAllowed)
    37  	}
    38  
    39  	return ctx
    40  }
    41  
    42  // verifyValuesMatch compares the provided BSON values and returns an error if they do not match. If the values are
    43  // documents and extraKeysAllowed is true, the actual value will be allowed to have additional keys at the top-level.
    44  // For example, an expected document {x: 1} would match the actual document {x: 1, y: 1}.
    45  func verifyValuesMatch(ctx context.Context, expected, actual bson.RawValue, extraKeysAllowed bool) error {
    46  	return verifyValuesMatchInner(makeMatchContext(ctx, "", extraKeysAllowed), expected, actual)
    47  }
    48  
    49  func verifyValuesMatchInner(ctx context.Context, expected, actual bson.RawValue) error {
    50  	keyPath := ctx.Value(keyPathCtxKey{}).(string)
    51  	extraKeysAllowed := ctx.Value(extraKeysAllowedCtxKey{}).(bool)
    52  
    53  	if expectedDoc, ok := expected.DocumentOK(); ok {
    54  		// If the root document only has one element and the key is a special matching operator, the actual value might
    55  		// not actually be a document. In this case, evaluate the special operator with the actual value rather than
    56  		// doing an element-wise document comparison.
    57  		if requiresSpecialMatching(expectedDoc) {
    58  			if err := evaluateSpecialComparison(ctx, expectedDoc, actual); err != nil {
    59  				return newMatchingError(keyPath, "error doing special matching assertion: %v", err)
    60  			}
    61  			return nil
    62  		}
    63  
    64  		actualDoc, ok := actual.DocumentOK()
    65  		if !ok {
    66  			return newMatchingError(keyPath, "expected value to be a document but got a %s", actual.Type)
    67  		}
    68  
    69  		// Perform element-wise comparisons.
    70  		expectedElems, _ := expectedDoc.Elements()
    71  		for _, expectedElem := range expectedElems {
    72  			expectedKey := expectedElem.Key()
    73  			expectedValue := expectedElem.Value()
    74  
    75  			fullKeyPath := expectedKey
    76  			if keyPath != "" {
    77  				fullKeyPath = keyPath + "." + expectedKey
    78  			}
    79  
    80  			// Get the value from actualDoc here but don't check the error until later because some of the special
    81  			// matching operators can assert that the value isn't present in the document (e.g. $$exists).
    82  			actualValue, err := actualDoc.LookupErr(expectedKey)
    83  			if specialDoc, ok := expectedValue.DocumentOK(); ok && requiresSpecialMatching(specialDoc) {
    84  				// Reset the key path so any errors returned from the function will only have the key path for the
    85  				// target value. Also unconditionally set extraKeysAllowed to false because an assertion like
    86  				// $$unsetOrMatches could recurse back into this function. In that case, the target document is nested
    87  				// and should not have extra keys.
    88  				ctx = makeMatchContext(ctx, "", false)
    89  				if err := evaluateSpecialComparison(ctx, specialDoc, actualValue); err != nil {
    90  					return newMatchingError(fullKeyPath, "error doing special matching assertion: %v", err)
    91  				}
    92  				continue
    93  			}
    94  
    95  			// This isn't a special comparison. Assert that the value exists in the actual document.
    96  			if err != nil {
    97  				return newMatchingError(fullKeyPath, "key not found in actual document")
    98  			}
    99  
   100  			// Check to see if the keypath requires us to convert actual/expected to make a true comparison.  If the
   101  			// comparison is not supported for the keypath, continue with the recursive strategy.
   102  			//
   103  			// TODO(GODRIVER-2386): this branch of logic will be removed once we add document support for comments
   104  			mixedTypeEvaluated, err := evaluateMixedTypeComparison(expectedKey, expectedValue, actualValue)
   105  			if err != nil {
   106  				return newMatchingError(fullKeyPath, "error doing mixed-type matching assertion: %v", err)
   107  			}
   108  			if mixedTypeEvaluated {
   109  				continue
   110  			}
   111  
   112  			// Nested documents cannot have extra keys, so we unconditionally pass false for extraKeysAllowed.
   113  			comparisonCtx := makeMatchContext(ctx, fullKeyPath, false)
   114  			if err := verifyValuesMatchInner(comparisonCtx, expectedValue, actualValue); err != nil {
   115  				return err
   116  			}
   117  		}
   118  		// If required, verify that the actual document does not have extra elements. We do this by iterating over the
   119  		// actual and checking for each key in the expected rather than comparing element counts because the presence of
   120  		// special operators can cause incorrect counts. For example, the document {y: {$$exists: false}} has one
   121  		// element, but should match the document {}, which has none.
   122  		if !extraKeysAllowed {
   123  			actualElems, _ := actualDoc.Elements()
   124  			for _, actualElem := range actualElems {
   125  				if _, err := expectedDoc.LookupErr(actualElem.Key()); err != nil {
   126  					return newMatchingError(keyPath, "extra key %q found in actual document %s", actualElem.Key(),
   127  						actualDoc)
   128  				}
   129  			}
   130  		}
   131  
   132  		return nil
   133  	}
   134  	if expectedArr, ok := expected.ArrayOK(); ok {
   135  		actualArr, ok := actual.ArrayOK()
   136  		if !ok {
   137  			return newMatchingError(keyPath, "expected value to be an array but got a %s", actual.Type)
   138  		}
   139  
   140  		expectedValues, _ := expectedArr.Values()
   141  		actualValues, _ := actualArr.Values()
   142  
   143  		// Arrays must always have the same number of elements.
   144  		if len(expectedValues) != len(actualValues) {
   145  			return newMatchingError(keyPath, "expected array length %d, got %d", len(expectedValues),
   146  				len(actualValues))
   147  		}
   148  
   149  		for idx, expectedValue := range expectedValues {
   150  			// Use the index as the key to augment the key path.
   151  			fullKeyPath := fmt.Sprintf("%d", idx)
   152  			if keyPath != "" {
   153  				fullKeyPath = keyPath + "." + fullKeyPath
   154  			}
   155  
   156  			comparisonCtx := makeMatchContext(ctx, fullKeyPath, extraKeysAllowed)
   157  			err := verifyValuesMatchInner(comparisonCtx, expectedValue, actualValues[idx])
   158  			if err != nil {
   159  				return err
   160  			}
   161  		}
   162  
   163  		return nil
   164  	}
   165  
   166  	// Numeric values must be considered equal even if their types are different (e.g. if expected is an int32 and
   167  	// actual is an int64).
   168  	if expected.IsNumber() {
   169  		if !actual.IsNumber() {
   170  			return newMatchingError(keyPath, "expected value to be a number but got a %s", actual.Type)
   171  		}
   172  
   173  		expectedInt64 := expected.AsInt64()
   174  		actualInt64 := actual.AsInt64()
   175  		if expectedInt64 != actualInt64 {
   176  			return newMatchingError(keyPath, "expected numeric value %d, got %d", expectedInt64, actualInt64)
   177  		}
   178  		return nil
   179  	}
   180  
   181  	// If expected is not a recursive or numeric type, we can directly call Equal to do the comparison.
   182  	if !expected.Equal(actual) {
   183  		return newMatchingError(keyPath, "expected value %s, got %s", expected, actual)
   184  	}
   185  	return nil
   186  }
   187  
   188  // compareDocumentToString will compare an expected document to an actual string by converting the document into a
   189  // string.
   190  func compareDocumentToString(expected, actual bson.RawValue) error {
   191  	expectedDocument, ok := expected.DocumentOK()
   192  	if !ok {
   193  		return fmt.Errorf("expected value to be a document but got a %s", expected.Type)
   194  	}
   195  
   196  	actualString, ok := actual.StringValueOK()
   197  	if !ok {
   198  		return fmt.Errorf("expected value to be a string but got a %s", actual.Type)
   199  	}
   200  
   201  	if actualString != expectedDocument.String() {
   202  		return fmt.Errorf("expected value %s, got %s", expectedDocument.String(), actualString)
   203  	}
   204  	return nil
   205  }
   206  
   207  // evaluateMixedTypeComparison compares an expected document with an actual string.  If this comparison occurs, then
   208  // the function will return `true` along with any resulting error.
   209  func evaluateMixedTypeComparison(expectedKey string, expected, actual bson.RawValue) (bool, error) {
   210  	switch expectedKey {
   211  	case "comment":
   212  		if expected.Type == bsontype.EmbeddedDocument && actual.Type == bsontype.String {
   213  			return true, compareDocumentToString(expected, actual)
   214  		}
   215  	}
   216  	return false, nil
   217  }
   218  
   219  func evaluateSpecialComparison(ctx context.Context, assertionDoc bson.Raw, actual bson.RawValue) error {
   220  	assertionElem := assertionDoc.Index(0)
   221  	assertion := assertionElem.Key()
   222  	assertionVal := assertionElem.Value()
   223  	extraKeysAllowed := ctx.Value(extraKeysAllowedCtxKey{}).(bool)
   224  	extraKeysRootMatchAllowed := ctx.Value(extraKeysAllowedRootMatchCtxKey{}).(bool)
   225  
   226  	switch assertion {
   227  	case "$$exists":
   228  		shouldExist := assertionVal.Boolean()
   229  		exists := actual.Validate() == nil
   230  		if shouldExist != exists {
   231  			return fmt.Errorf("expected value to exist: %v; value actually exists: %v", shouldExist, exists)
   232  		}
   233  	case "$$type":
   234  		possibleTypes, err := getTypesArray(assertionVal)
   235  		if err != nil {
   236  			return fmt.Errorf("error getting possible types for a $$type assertion: %v", err)
   237  		}
   238  
   239  		for _, possibleType := range possibleTypes {
   240  			if actual.Type == possibleType {
   241  				return nil
   242  			}
   243  		}
   244  		return fmt.Errorf("expected type to be one of %v but was %s", possibleTypes, actual.Type)
   245  	case "$$matchesEntity":
   246  		expected, err := entities(ctx).BSONValue(assertionVal.StringValue())
   247  		if err != nil {
   248  			return err
   249  		}
   250  
   251  		// $$matchesEntity doesn't modify the nesting level of the key path so we can propagate ctx without changes.
   252  		return verifyValuesMatchInner(ctx, expected, actual)
   253  	case "$$matchesHexBytes":
   254  		expectedBytes, err := hex.DecodeString(assertionVal.StringValue())
   255  		if err != nil {
   256  			return fmt.Errorf("error converting $$matcesHexBytes value to bytes: %v", err)
   257  		}
   258  
   259  		_, actualBytes, ok := actual.BinaryOK()
   260  		if !ok {
   261  			return fmt.Errorf("expected binary value for a $$matchesHexBytes assertion, but got a %s", actual.Type)
   262  		}
   263  		if !bytes.Equal(expectedBytes, actualBytes) {
   264  			return fmt.Errorf("expected bytes %v, got %v", expectedBytes, actualBytes)
   265  		}
   266  	case "$$unsetOrMatches":
   267  		if actual.Validate() != nil {
   268  			return nil
   269  		}
   270  
   271  		// $$unsetOrMatches doesn't modify the nesting level or the key path so we can propagate the context to the
   272  		// comparison function without changing anything.
   273  		return verifyValuesMatchInner(ctx, assertionVal, actual)
   274  	case "$$sessionLsid":
   275  		sess, err := entities(ctx).session(assertionVal.StringValue())
   276  		if err != nil {
   277  			return err
   278  		}
   279  
   280  		expectedID := sess.ID()
   281  		actualID, ok := actual.DocumentOK()
   282  		if !ok {
   283  			return fmt.Errorf("expected document value for a $$sessionLsid assertion, but got a %s", actual.Type)
   284  		}
   285  		if !bytes.Equal(expectedID, actualID) {
   286  			return fmt.Errorf("expected lsid %v, got %v", expectedID, actualID)
   287  		}
   288  	case "$$lte":
   289  		if assertionVal.Type != bsontype.Int32 && assertionVal.Type != bsontype.Int64 {
   290  			return fmt.Errorf("expected assertionVal to be an Int32 or Int64 but got a %s", assertionVal.Type)
   291  		}
   292  		if actual.Type != bsontype.Int32 && actual.Type != bsontype.Int64 {
   293  			return fmt.Errorf("expected value to be an Int32 or Int64 but got a %s", actual.Type)
   294  		}
   295  
   296  		// Numeric values can be compared even if their types are different (e.g. if expected is an int32 and actual
   297  		// is an int64).
   298  		expectedInt64 := assertionVal.AsInt64()
   299  		actualInt64 := actual.AsInt64()
   300  		if actualInt64 > expectedInt64 {
   301  			return fmt.Errorf("expected numeric value %d to be less than or equal %d", actualInt64, expectedInt64)
   302  		}
   303  		return nil
   304  	case "$$matchAsDocument":
   305  		var actualDoc bson.Raw
   306  		str, ok := actual.StringValueOK()
   307  		if !ok {
   308  			return fmt.Errorf("expected value to be a string but got a %s", actual.Type)
   309  		}
   310  
   311  		if err := bson.UnmarshalExtJSON([]byte(str), true, &actualDoc); err != nil {
   312  			return fmt.Errorf("error unmarshalling string as document: %v", err)
   313  		}
   314  
   315  		if err := verifyValuesMatch(ctx, assertionVal, documentToRawValue(actualDoc), extraKeysAllowed); err != nil {
   316  			return fmt.Errorf("error matching $$matchAsRoot assertion: %v", err)
   317  		}
   318  	case "$$matchAsRoot":
   319  		// Treat the actual value as a root-level document that can have extra keys that are not subject to
   320  		// the matching rules.
   321  		if err := verifyValuesMatch(ctx, assertionVal, actual, extraKeysRootMatchAllowed); err != nil {
   322  			return fmt.Errorf("error matching $$matchAsRoot assertion: %v", err)
   323  		}
   324  	default:
   325  		return fmt.Errorf("unrecognized special matching assertion %q", assertion)
   326  	}
   327  
   328  	return nil
   329  }
   330  
   331  func requiresSpecialMatching(doc bson.Raw) bool {
   332  	elems, _ := doc.Elements()
   333  	return len(elems) == 1 && strings.HasPrefix(elems[0].Key(), "$$")
   334  }
   335  
   336  func getTypesArray(val bson.RawValue) ([]bsontype.Type, error) {
   337  	switch val.Type {
   338  	case bsontype.String:
   339  		convertedType, err := convertStringToBSONType(val.StringValue())
   340  		if err != nil {
   341  			return nil, err
   342  		}
   343  
   344  		return []bsontype.Type{convertedType}, nil
   345  	case bsontype.Array:
   346  		var typeStrings []string
   347  		if err := val.Unmarshal(&typeStrings); err != nil {
   348  			return nil, fmt.Errorf("error unmarshalling to slice of strings: %v", err)
   349  		}
   350  
   351  		var types []bsontype.Type
   352  		for _, typeStr := range typeStrings {
   353  			convertedType, err := convertStringToBSONType(typeStr)
   354  			if err != nil {
   355  				return nil, err
   356  			}
   357  
   358  			types = append(types, convertedType)
   359  		}
   360  		return types, nil
   361  	default:
   362  		return nil, fmt.Errorf("invalid type to convert to bsontype.Type slice: %s", val.Type)
   363  	}
   364  }
   365  
   366  func convertStringToBSONType(typeStr string) (bsontype.Type, error) {
   367  	switch typeStr {
   368  	case "double":
   369  		return bsontype.Double, nil
   370  	case "string":
   371  		return bsontype.String, nil
   372  	case "object":
   373  		return bsontype.EmbeddedDocument, nil
   374  	case "array":
   375  		return bsontype.Array, nil
   376  	case "binData":
   377  		return bsontype.Binary, nil
   378  	case "undefined":
   379  		return bsontype.Undefined, nil
   380  	case "objectId":
   381  		return bsontype.ObjectID, nil
   382  	case "bool":
   383  		return bsontype.Boolean, nil
   384  	case "date":
   385  		return bsontype.DateTime, nil
   386  	case "null":
   387  		return bsontype.Null, nil
   388  	case "regex":
   389  		return bsontype.Regex, nil
   390  	case "dbPointer":
   391  		return bsontype.DBPointer, nil
   392  	case "javascript":
   393  		return bsontype.JavaScript, nil
   394  	case "symbol":
   395  		return bsontype.Symbol, nil
   396  	case "javascriptWithScope":
   397  		return bsontype.CodeWithScope, nil
   398  	case "int":
   399  		return bsontype.Int32, nil
   400  	case "timestamp":
   401  		return bsontype.Timestamp, nil
   402  	case "long":
   403  		return bsontype.Int64, nil
   404  	case "decimal":
   405  		return bsontype.Decimal128, nil
   406  	case "minKey":
   407  		return bsontype.MinKey, nil
   408  	case "maxKey":
   409  		return bsontype.MaxKey, nil
   410  	default:
   411  		return bsontype.Type(0), fmt.Errorf("unrecognized BSON type string %q", typeStr)
   412  	}
   413  }
   414  
   415  // newMatchingError creates an error to convey that BSON value comparison failed at the provided key path. If the
   416  // key path is empty (e.g. because the values being compared were not documents), the error message will contain the
   417  // phrase "top-level" instead of the path.
   418  func newMatchingError(keyPath, msg string, args ...interface{}) error {
   419  	fullMsg := fmt.Sprintf(msg, args...)
   420  	if keyPath == "" {
   421  		return fmt.Errorf("comparison error at top-level: %s", fullMsg)
   422  	}
   423  	return fmt.Errorf("comparison error at key %q: %s", keyPath, fullMsg)
   424  }
   425  

View as plain text