...

Source file src/go.mongodb.org/mongo-driver/mongo/integration/unified/testrunner_operation.go

Documentation: go.mongodb.org/mongo-driver/mongo/integration/unified

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package unified
     8  
     9  import (
    10  	"context"
    11  	"fmt"
    12  	"strings"
    13  	"time"
    14  
    15  	"go.mongodb.org/mongo-driver/bson"
    16  	"go.mongodb.org/mongo-driver/mongo"
    17  	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
    18  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    19  	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
    20  )
    21  
    22  // waitForEventTimeout is the amount of time to wait for an event to occur. The
    23  // maximum amount of time expected for this value is currently 10 seconds, which
    24  // is the amount of time that the driver will attempt to wait between streamable
    25  // heartbeats. Increase this value if a new maximum time is expected in another
    26  // operation.
    27  var waitForEventTimeout = 11 * time.Second
    28  
    29  type loopArgs struct {
    30  	Operations         []*operation `bson:"operations"`
    31  	ErrorsEntityID     string       `bson:"storeErrorsAsEntity"`
    32  	FailuresEntityID   string       `bson:"storeFailuresAsEntity"`
    33  	SuccessesEntityID  string       `bson:"storeSuccessesAsEntity"`
    34  	IterationsEntityID string       `bson:"storeIterationsAsEntity"`
    35  }
    36  
    37  func (lp *loopArgs) errorsStored() bool {
    38  	return lp.ErrorsEntityID != ""
    39  }
    40  
    41  func (lp *loopArgs) failuresStored() bool {
    42  	return lp.FailuresEntityID != ""
    43  }
    44  
    45  func (lp *loopArgs) successesStored() bool {
    46  	return lp.SuccessesEntityID != ""
    47  }
    48  
    49  func (lp *loopArgs) iterationsStored() bool {
    50  	return lp.IterationsEntityID != ""
    51  }
    52  
    53  func executeTestRunnerOperation(ctx context.Context, op *operation, loopDone <-chan struct{}) error {
    54  	args := op.Arguments
    55  
    56  	switch op.Name {
    57  	case "failPoint":
    58  		clientID := lookupString(args, "client")
    59  		client, err := entities(ctx).client(clientID)
    60  		if err != nil {
    61  			return err
    62  		}
    63  
    64  		fpDoc := args.Lookup("failPoint").Document()
    65  		if err := mtest.SetRawFailPoint(fpDoc, client.Client); err != nil {
    66  			return err
    67  		}
    68  		return addFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), client.Client)
    69  	case "targetedFailPoint":
    70  		sessID := lookupString(args, "session")
    71  		sess, err := entities(ctx).session(sessID)
    72  		if err != nil {
    73  			return err
    74  		}
    75  
    76  		clientSession := extractClientSession(sess)
    77  		if clientSession.PinnedServer == nil {
    78  			return fmt.Errorf("session is not pinned to a server")
    79  		}
    80  
    81  		targetHost := clientSession.PinnedServer.Addr.String()
    82  		fpDoc := args.Lookup("failPoint").Document()
    83  		commandFn := func(ctx context.Context, client *mongo.Client) error {
    84  			return mtest.SetRawFailPoint(fpDoc, client)
    85  		}
    86  
    87  		if err := runCommandOnHost(ctx, targetHost, commandFn); err != nil {
    88  			return err
    89  		}
    90  		return addTargetedFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), targetHost)
    91  	case "assertSessionTransactionState":
    92  		sessID := lookupString(args, "session")
    93  		sess, err := entities(ctx).session(sessID)
    94  		if err != nil {
    95  			return err
    96  		}
    97  
    98  		var expectedState session.TransactionState
    99  		switch stateStr := lookupString(args, "state"); stateStr {
   100  		case "none":
   101  			expectedState = session.None
   102  		case "starting":
   103  			expectedState = session.Starting
   104  		case "in_progress":
   105  			expectedState = session.InProgress
   106  		case "committed":
   107  			expectedState = session.Committed
   108  		case "aborted":
   109  			expectedState = session.Aborted
   110  		default:
   111  			return fmt.Errorf("unrecognized session state type %q", stateStr)
   112  		}
   113  
   114  		if actualState := extractClientSession(sess).TransactionState; actualState != expectedState {
   115  			return fmt.Errorf("expected session state %q does not match actual state %q", expectedState, actualState)
   116  		}
   117  		return nil
   118  	case "assertSessionPinned":
   119  		return verifySessionPinnedState(ctx, lookupString(args, "session"), true)
   120  	case "assertSessionUnpinned":
   121  		return verifySessionPinnedState(ctx, lookupString(args, "session"), false)
   122  	case "assertSameLsidOnLastTwoCommands":
   123  		return verifyLastTwoLsidsEqual(ctx, lookupString(args, "client"), true)
   124  	case "assertDifferentLsidOnLastTwoCommands":
   125  		return verifyLastTwoLsidsEqual(ctx, lookupString(args, "client"), false)
   126  	case "assertSessionDirty":
   127  		return verifySessionDirtyState(ctx, lookupString(args, "session"), true)
   128  	case "assertSessionNotDirty":
   129  		return verifySessionDirtyState(ctx, lookupString(args, "session"), false)
   130  	case "assertCollectionExists":
   131  		db := lookupString(args, "databaseName")
   132  		coll := lookupString(args, "collectionName")
   133  		return verifyCollectionExists(ctx, db, coll, true)
   134  	case "assertCollectionNotExists":
   135  		db := lookupString(args, "databaseName")
   136  		coll := lookupString(args, "collectionName")
   137  		return verifyCollectionExists(ctx, db, coll, false)
   138  	case "assertIndexExists":
   139  		db := lookupString(args, "databaseName")
   140  		coll := lookupString(args, "collectionName")
   141  		index := lookupString(args, "indexName")
   142  		return verifyIndexExists(ctx, db, coll, index, true)
   143  	case "assertIndexNotExists":
   144  		db := lookupString(args, "databaseName")
   145  		coll := lookupString(args, "collectionName")
   146  		index := lookupString(args, "indexName")
   147  		return verifyIndexExists(ctx, db, coll, index, false)
   148  	case "loop":
   149  		var unmarshaledArgs loopArgs
   150  		if err := bson.Unmarshal(args, &unmarshaledArgs); err != nil {
   151  			return fmt.Errorf("error unmarshalling arguments to loopArgs: %v", err)
   152  		}
   153  		return executeLoop(ctx, &unmarshaledArgs, loopDone)
   154  	case "assertNumberConnectionsCheckedOut":
   155  		clientID := lookupString(args, "client")
   156  		client, err := entities(ctx).client(clientID)
   157  		if err != nil {
   158  			return err
   159  		}
   160  
   161  		expected := int32(lookupInteger(args, "connections"))
   162  		actual := client.numberConnectionsCheckedOut()
   163  		if expected != actual {
   164  			return fmt.Errorf("expected %d connections to be checked out, got %d", expected, actual)
   165  		}
   166  		return nil
   167  	case "createEntities":
   168  		entitiesRaw, err := args.LookupErr("entities")
   169  		if err != nil {
   170  			return fmt.Errorf("'entities' argument not found in createEntities operation")
   171  		}
   172  
   173  		var createEntities []map[string]*entityOptions
   174  		if err := entitiesRaw.Unmarshal(&createEntities); err != nil {
   175  			return fmt.Errorf("error unmarshalling 'entities' argument to entityOptions: %v", err)
   176  		}
   177  
   178  		for idx, entity := range createEntities {
   179  			for entityType, entityOptions := range entity {
   180  				if entityType == "client" && hasOperationalFailpoint(ctx) {
   181  					entityOptions.setHeartbeatFrequencyMS(lowHeartbeatFrequency)
   182  				}
   183  
   184  				if err := entities(ctx).addEntity(ctx, entityType, entityOptions); err != nil {
   185  					return fmt.Errorf("error creating entity at index %d: %v", idx, err)
   186  				}
   187  			}
   188  		}
   189  		return nil
   190  	case "runOnThread":
   191  		operationRaw, err := args.LookupErr("operation")
   192  		if err != nil {
   193  			return fmt.Errorf("'operation' argument not found in runOnThread operation")
   194  		}
   195  		threadOp := new(operation)
   196  		if err := operationRaw.Unmarshal(threadOp); err != nil {
   197  			return fmt.Errorf("error unmarshaling 'operation' argument: %v", err)
   198  		}
   199  		thread := lookupString(args, "thread")
   200  		routine, ok := entities(ctx).routinesMap.Load(thread)
   201  		if !ok {
   202  			return fmt.Errorf("run on unknown thread: %s", thread)
   203  		}
   204  		routine.(*backgroundRoutine).addTask(threadOp.Name, func() error {
   205  			return threadOp.execute(ctx, loopDone)
   206  		})
   207  		return nil
   208  	case "waitForThread":
   209  		thread := lookupString(args, "thread")
   210  		routine, ok := entities(ctx).routinesMap.Load(thread)
   211  		if !ok {
   212  			return fmt.Errorf("wait for unknown thread: %s", thread)
   213  		}
   214  		return routine.(*backgroundRoutine).stop()
   215  	case "waitForEvent":
   216  		var wfeArgs waitForEventArguments
   217  		if err := bson.Unmarshal(op.Arguments, &wfeArgs); err != nil {
   218  			return fmt.Errorf("error unmarshalling event to waitForEventArguments: %v", err)
   219  		}
   220  
   221  		wfeCtx, cancel := context.WithTimeout(ctx, waitForEventTimeout)
   222  		defer cancel()
   223  
   224  		return waitForEvent(wfeCtx, wfeArgs)
   225  	default:
   226  		return fmt.Errorf("unrecognized testRunner operation %q", op.Name)
   227  	}
   228  }
   229  
   230  func executeLoop(ctx context.Context, args *loopArgs, loopDone <-chan struct{}) error {
   231  	// setup entities
   232  	entityMap := entities(ctx)
   233  	if args.errorsStored() {
   234  		if err := entityMap.addBSONArrayEntity(args.ErrorsEntityID); err != nil {
   235  			return err
   236  		}
   237  	}
   238  	if args.failuresStored() {
   239  		if err := entityMap.addBSONArrayEntity(args.FailuresEntityID); err != nil {
   240  			return err
   241  		}
   242  	}
   243  	if args.successesStored() {
   244  		if err := entityMap.addSuccessesEntity(args.SuccessesEntityID); err != nil {
   245  			return err
   246  		}
   247  	}
   248  	if args.iterationsStored() {
   249  		if err := entityMap.addIterationsEntity(args.IterationsEntityID); err != nil {
   250  			return err
   251  		}
   252  	}
   253  
   254  	for {
   255  		select {
   256  		case <-loopDone:
   257  			return nil
   258  		default:
   259  			if args.iterationsStored() {
   260  				if err := entityMap.incrementIterations(args.IterationsEntityID); err != nil {
   261  					return err
   262  				}
   263  			}
   264  			var loopErr error
   265  			for i, operation := range args.Operations {
   266  				if operation.Name == "loop" {
   267  					return fmt.Errorf("loop sub-operations should not include loop")
   268  				}
   269  				loopErr = operation.execute(ctx, loopDone)
   270  
   271  				// if the operation errors, stop this loop
   272  				if loopErr != nil {
   273  					// If StoreFailures or StoreErrors is set, continue looping, otherwise break
   274  					if !args.errorsStored() && !args.failuresStored() {
   275  						return fmt.Errorf("error running loop operation %v : %v", i, loopErr)
   276  					}
   277  					errDoc := bson.Raw(bsoncore.NewDocumentBuilder().
   278  						AppendString("error", loopErr.Error()).
   279  						AppendDouble("time", getSecondsSinceEpoch()).
   280  						Build())
   281  					var appendErr error
   282  					switch {
   283  					case !args.errorsStored(): // store errors as failures if storeErrorsAsEntity isn't specified
   284  						appendErr = entityMap.appendBSONArrayEntity(args.FailuresEntityID, errDoc)
   285  					case !args.failuresStored(): // store failures as errors if storeFailuressAsEntity isn't specified
   286  						appendErr = entityMap.appendBSONArrayEntity(args.ErrorsEntityID, errDoc)
   287  					// errors are test runner errors
   288  					// TODO GODRIVER-1950: use error types to determine error vs failure instead of depending on the
   289  					// TODO fact that operation.execute prepends "execution failed" to test runner errors
   290  					case strings.Contains(loopErr.Error(), "execution failed: "):
   291  						appendErr = entityMap.appendBSONArrayEntity(args.ErrorsEntityID, errDoc)
   292  					// failures are if an operation returns an incorrect result or error
   293  					default:
   294  						appendErr = entityMap.appendBSONArrayEntity(args.FailuresEntityID, errDoc)
   295  					}
   296  					if appendErr != nil {
   297  						return appendErr
   298  					}
   299  					// if a sub-operation errors, restart the loop
   300  					break
   301  				}
   302  				if args.successesStored() {
   303  					if err := entityMap.incrementSuccesses(args.SuccessesEntityID); err != nil {
   304  						return err
   305  					}
   306  				}
   307  			}
   308  		}
   309  	}
   310  }
   311  
   312  type waitForEventArguments struct {
   313  	ClientID string              `bson:"client"`
   314  	Event    map[string]bson.Raw `bson:"event"`
   315  	Count    int32               `bson:"count"`
   316  }
   317  
   318  // getServerDescriptionChangedEventCount will return "true" if a specific
   319  // server description change event has occurred, up to the description type.
   320  //
   321  // If the bson.Raw value is empty, then this function will only consider if a
   322  // serverDescriptionChangeEvent has occurred at all.
   323  //
   324  // If the bson.Raw contains newDescription and/or previousDescription, this
   325  // function will attempt to compare them to events up to the fields defined in
   326  // the UST specifications.
   327  func getServerDescriptionChangedEventCount(client *clientEntity, raw bson.Raw) int32 {
   328  	if len(raw) == 0 {
   329  		return 0
   330  	}
   331  
   332  	// If the document has no values, then we assume that the UST only
   333  	// intends to check that the event happened.
   334  	if values, _ := raw.Values(); len(values) == 0 {
   335  		return client.getEventCount(serverDescriptionChangedEvent)
   336  	}
   337  
   338  	var expectedEvt serverDescriptionChangedEventInfo
   339  	if err := bson.Unmarshal(raw, &expectedEvt); err != nil {
   340  		return 0
   341  	}
   342  
   343  	return client.getServerDescriptionChangedEventCount(expectedEvt)
   344  }
   345  
   346  // eventCompleted will check all of the events in the event map and return true if all of the events have at least the
   347  // specified number of occurrences. If the event map is empty, it will return true.
   348  func (args waitForEventArguments) eventCompleted(client *clientEntity) bool {
   349  	for rawEventType, eventDoc := range args.Event {
   350  		eventType, ok := monitoringEventTypeFromString(rawEventType)
   351  		if !ok {
   352  			return false
   353  		}
   354  
   355  		switch eventType {
   356  		case serverDescriptionChangedEvent:
   357  			if getServerDescriptionChangedEventCount(client, eventDoc) < args.Count {
   358  				return false
   359  			}
   360  		default:
   361  			if client.getEventCount(eventType) < args.Count {
   362  				return false
   363  			}
   364  		}
   365  	}
   366  
   367  	return true
   368  }
   369  
   370  func waitForEvent(ctx context.Context, args waitForEventArguments) error {
   371  	client, err := entities(ctx).client(args.ClientID)
   372  	if err != nil {
   373  		return err
   374  	}
   375  
   376  	for {
   377  		select {
   378  		case <-ctx.Done():
   379  			return fmt.Errorf("timed out waiting for event: %v", ctx.Err())
   380  		default:
   381  			if args.eventCompleted(client) {
   382  				return nil
   383  			}
   384  
   385  		}
   386  
   387  		time.Sleep(100 * time.Millisecond)
   388  	}
   389  }
   390  
   391  func extractClientSession(sess mongo.Session) *session.Client {
   392  	return sess.(mongo.XSession).ClientSession()
   393  }
   394  
   395  func verifySessionPinnedState(ctx context.Context, sessionID string, expectedPinned bool) error {
   396  	sess, err := entities(ctx).session(sessionID)
   397  	if err != nil {
   398  		return err
   399  	}
   400  
   401  	if isPinned := extractClientSession(sess).PinnedServer != nil; expectedPinned != isPinned {
   402  		return fmt.Errorf("session pinned state mismatch; expected to be pinned: %v, is pinned: %v", expectedPinned, isPinned)
   403  	}
   404  	return nil
   405  }
   406  
   407  func verifyLastTwoLsidsEqual(ctx context.Context, clientID string, expectedEqual bool) error {
   408  	client, err := entities(ctx).client(clientID)
   409  	if err != nil {
   410  		return err
   411  	}
   412  
   413  	allEvents := client.startedEvents()
   414  	if len(allEvents) < 2 {
   415  		return fmt.Errorf("client has recorded fewer than two command started events")
   416  	}
   417  	lastTwoEvents := allEvents[len(allEvents)-2:]
   418  
   419  	firstID, err := lastTwoEvents[0].Command.LookupErr("lsid")
   420  	if err != nil {
   421  		return fmt.Errorf("first command has no 'lsid' field: %v", client.started[0].Command)
   422  	}
   423  	secondID, err := lastTwoEvents[1].Command.LookupErr("lsid")
   424  	if err != nil {
   425  		return fmt.Errorf("first command has no 'lsid' field: %v", client.started[1].Command)
   426  	}
   427  
   428  	areEqual := firstID.Equal(secondID)
   429  	if expectedEqual && !areEqual {
   430  		return fmt.Errorf("expected last two lsids to be equal, but got %s and %s", firstID, secondID)
   431  	}
   432  	if !expectedEqual && areEqual {
   433  		return fmt.Errorf("expected last two lsids to be different but both were %s", firstID)
   434  	}
   435  	return nil
   436  }
   437  
   438  func verifySessionDirtyState(ctx context.Context, sessionID string, expectedDirty bool) error {
   439  	sess, err := entities(ctx).session(sessionID)
   440  	if err != nil {
   441  		return err
   442  	}
   443  
   444  	if isDirty := extractClientSession(sess).Dirty; expectedDirty != isDirty {
   445  		return fmt.Errorf("session dirty state mismatch; expected to be dirty: %v, is dirty: %v", expectedDirty, isDirty)
   446  	}
   447  	return nil
   448  }
   449  
   450  func verifyCollectionExists(ctx context.Context, dbName, collName string, expectedExists bool) error {
   451  	db := mtest.GlobalClient().Database(dbName)
   452  	collections, err := db.ListCollectionNames(ctx, bson.M{"name": collName})
   453  	if err != nil {
   454  		return fmt.Errorf("error running ListCollectionNames: %v", err)
   455  	}
   456  
   457  	if exists := len(collections) == 1; expectedExists != exists {
   458  		ns := fmt.Sprintf("%s.%s", dbName, collName)
   459  		return fmt.Errorf("collection existence mismatch; expected namespace %q to exist: %v, exists: %v", ns,
   460  			expectedExists, exists)
   461  	}
   462  	return nil
   463  }
   464  
   465  func verifyIndexExists(ctx context.Context, dbName, collName, indexName string, expectedExists bool) error {
   466  	iv := mtest.GlobalClient().Database(dbName).Collection(collName).Indexes()
   467  	cursor, err := iv.List(ctx)
   468  	if err != nil {
   469  		return fmt.Errorf("error running IndexView.List: %v", err)
   470  	}
   471  	defer cursor.Close(ctx)
   472  
   473  	var exists bool
   474  	for cursor.Next(ctx) {
   475  		if lookupString(cursor.Current, "name") == indexName {
   476  			exists = true
   477  			break
   478  		}
   479  	}
   480  	if expectedExists != exists {
   481  		ns := fmt.Sprintf("%s.%s", dbName, collName)
   482  		return fmt.Errorf("index existence mismatch: expected index %q to exist in namespace %q: %v, exists: %v",
   483  			indexName, ns, expectedExists, exists)
   484  	}
   485  	return nil
   486  }
   487  

View as plain text