1
2
3
4
5
6
7 package unified
8
9 import (
10 "context"
11 "fmt"
12 "strings"
13 "time"
14
15 "go.mongodb.org/mongo-driver/bson"
16 "go.mongodb.org/mongo-driver/mongo"
17 "go.mongodb.org/mongo-driver/mongo/integration/mtest"
18 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
19 "go.mongodb.org/mongo-driver/x/mongo/driver/session"
20 )
21
22
23
24
25
26
// waitForEventTimeout bounds how long a "waitForEvent" test runner operation
// keeps polling for the expected events before it reports a timeout.
var waitForEventTimeout = 11000 * time.Millisecond
28
// loopArgs holds the decoded arguments of a "loop" test runner operation.
// Each storeXAsEntity field names the entity that should receive the
// corresponding data; an empty string means that data is not stored.
type loopArgs struct {
	// Operations are the sub-operations executed, in order, on every
	// iteration of the loop.
	Operations []*operation `bson:"operations"`
	// ErrorsEntityID names the BSON array entity that receives recorded errors.
	ErrorsEntityID string `bson:"storeErrorsAsEntity"`
	// FailuresEntityID names the BSON array entity that receives recorded failures.
	FailuresEntityID string `bson:"storeFailuresAsEntity"`
	// SuccessesEntityID names the entity that counts successful sub-operations.
	SuccessesEntityID string `bson:"storeSuccessesAsEntity"`
	// IterationsEntityID names the entity that counts completed-or-started iterations.
	IterationsEntityID string `bson:"storeIterationsAsEntity"`
}
36
37 func (lp *loopArgs) errorsStored() bool {
38 return lp.ErrorsEntityID != ""
39 }
40
41 func (lp *loopArgs) failuresStored() bool {
42 return lp.FailuresEntityID != ""
43 }
44
45 func (lp *loopArgs) successesStored() bool {
46 return lp.SuccessesEntityID != ""
47 }
48
49 func (lp *loopArgs) iterationsStored() bool {
50 return lp.IterationsEntityID != ""
51 }
52
53 func executeTestRunnerOperation(ctx context.Context, op *operation, loopDone <-chan struct{}) error {
54 args := op.Arguments
55
56 switch op.Name {
57 case "failPoint":
58 clientID := lookupString(args, "client")
59 client, err := entities(ctx).client(clientID)
60 if err != nil {
61 return err
62 }
63
64 fpDoc := args.Lookup("failPoint").Document()
65 if err := mtest.SetRawFailPoint(fpDoc, client.Client); err != nil {
66 return err
67 }
68 return addFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), client.Client)
69 case "targetedFailPoint":
70 sessID := lookupString(args, "session")
71 sess, err := entities(ctx).session(sessID)
72 if err != nil {
73 return err
74 }
75
76 clientSession := extractClientSession(sess)
77 if clientSession.PinnedServer == nil {
78 return fmt.Errorf("session is not pinned to a server")
79 }
80
81 targetHost := clientSession.PinnedServer.Addr.String()
82 fpDoc := args.Lookup("failPoint").Document()
83 commandFn := func(ctx context.Context, client *mongo.Client) error {
84 return mtest.SetRawFailPoint(fpDoc, client)
85 }
86
87 if err := runCommandOnHost(ctx, targetHost, commandFn); err != nil {
88 return err
89 }
90 return addTargetedFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), targetHost)
91 case "assertSessionTransactionState":
92 sessID := lookupString(args, "session")
93 sess, err := entities(ctx).session(sessID)
94 if err != nil {
95 return err
96 }
97
98 var expectedState session.TransactionState
99 switch stateStr := lookupString(args, "state"); stateStr {
100 case "none":
101 expectedState = session.None
102 case "starting":
103 expectedState = session.Starting
104 case "in_progress":
105 expectedState = session.InProgress
106 case "committed":
107 expectedState = session.Committed
108 case "aborted":
109 expectedState = session.Aborted
110 default:
111 return fmt.Errorf("unrecognized session state type %q", stateStr)
112 }
113
114 if actualState := extractClientSession(sess).TransactionState; actualState != expectedState {
115 return fmt.Errorf("expected session state %q does not match actual state %q", expectedState, actualState)
116 }
117 return nil
118 case "assertSessionPinned":
119 return verifySessionPinnedState(ctx, lookupString(args, "session"), true)
120 case "assertSessionUnpinned":
121 return verifySessionPinnedState(ctx, lookupString(args, "session"), false)
122 case "assertSameLsidOnLastTwoCommands":
123 return verifyLastTwoLsidsEqual(ctx, lookupString(args, "client"), true)
124 case "assertDifferentLsidOnLastTwoCommands":
125 return verifyLastTwoLsidsEqual(ctx, lookupString(args, "client"), false)
126 case "assertSessionDirty":
127 return verifySessionDirtyState(ctx, lookupString(args, "session"), true)
128 case "assertSessionNotDirty":
129 return verifySessionDirtyState(ctx, lookupString(args, "session"), false)
130 case "assertCollectionExists":
131 db := lookupString(args, "databaseName")
132 coll := lookupString(args, "collectionName")
133 return verifyCollectionExists(ctx, db, coll, true)
134 case "assertCollectionNotExists":
135 db := lookupString(args, "databaseName")
136 coll := lookupString(args, "collectionName")
137 return verifyCollectionExists(ctx, db, coll, false)
138 case "assertIndexExists":
139 db := lookupString(args, "databaseName")
140 coll := lookupString(args, "collectionName")
141 index := lookupString(args, "indexName")
142 return verifyIndexExists(ctx, db, coll, index, true)
143 case "assertIndexNotExists":
144 db := lookupString(args, "databaseName")
145 coll := lookupString(args, "collectionName")
146 index := lookupString(args, "indexName")
147 return verifyIndexExists(ctx, db, coll, index, false)
148 case "loop":
149 var unmarshaledArgs loopArgs
150 if err := bson.Unmarshal(args, &unmarshaledArgs); err != nil {
151 return fmt.Errorf("error unmarshalling arguments to loopArgs: %v", err)
152 }
153 return executeLoop(ctx, &unmarshaledArgs, loopDone)
154 case "assertNumberConnectionsCheckedOut":
155 clientID := lookupString(args, "client")
156 client, err := entities(ctx).client(clientID)
157 if err != nil {
158 return err
159 }
160
161 expected := int32(lookupInteger(args, "connections"))
162 actual := client.numberConnectionsCheckedOut()
163 if expected != actual {
164 return fmt.Errorf("expected %d connections to be checked out, got %d", expected, actual)
165 }
166 return nil
167 case "createEntities":
168 entitiesRaw, err := args.LookupErr("entities")
169 if err != nil {
170 return fmt.Errorf("'entities' argument not found in createEntities operation")
171 }
172
173 var createEntities []map[string]*entityOptions
174 if err := entitiesRaw.Unmarshal(&createEntities); err != nil {
175 return fmt.Errorf("error unmarshalling 'entities' argument to entityOptions: %v", err)
176 }
177
178 for idx, entity := range createEntities {
179 for entityType, entityOptions := range entity {
180 if entityType == "client" && hasOperationalFailpoint(ctx) {
181 entityOptions.setHeartbeatFrequencyMS(lowHeartbeatFrequency)
182 }
183
184 if err := entities(ctx).addEntity(ctx, entityType, entityOptions); err != nil {
185 return fmt.Errorf("error creating entity at index %d: %v", idx, err)
186 }
187 }
188 }
189 return nil
190 case "runOnThread":
191 operationRaw, err := args.LookupErr("operation")
192 if err != nil {
193 return fmt.Errorf("'operation' argument not found in runOnThread operation")
194 }
195 threadOp := new(operation)
196 if err := operationRaw.Unmarshal(threadOp); err != nil {
197 return fmt.Errorf("error unmarshaling 'operation' argument: %v", err)
198 }
199 thread := lookupString(args, "thread")
200 routine, ok := entities(ctx).routinesMap.Load(thread)
201 if !ok {
202 return fmt.Errorf("run on unknown thread: %s", thread)
203 }
204 routine.(*backgroundRoutine).addTask(threadOp.Name, func() error {
205 return threadOp.execute(ctx, loopDone)
206 })
207 return nil
208 case "waitForThread":
209 thread := lookupString(args, "thread")
210 routine, ok := entities(ctx).routinesMap.Load(thread)
211 if !ok {
212 return fmt.Errorf("wait for unknown thread: %s", thread)
213 }
214 return routine.(*backgroundRoutine).stop()
215 case "waitForEvent":
216 var wfeArgs waitForEventArguments
217 if err := bson.Unmarshal(op.Arguments, &wfeArgs); err != nil {
218 return fmt.Errorf("error unmarshalling event to waitForEventArguments: %v", err)
219 }
220
221 wfeCtx, cancel := context.WithTimeout(ctx, waitForEventTimeout)
222 defer cancel()
223
224 return waitForEvent(wfeCtx, wfeArgs)
225 default:
226 return fmt.Errorf("unrecognized testRunner operation %q", op.Name)
227 }
228 }
229
230 func executeLoop(ctx context.Context, args *loopArgs, loopDone <-chan struct{}) error {
231
232 entityMap := entities(ctx)
233 if args.errorsStored() {
234 if err := entityMap.addBSONArrayEntity(args.ErrorsEntityID); err != nil {
235 return err
236 }
237 }
238 if args.failuresStored() {
239 if err := entityMap.addBSONArrayEntity(args.FailuresEntityID); err != nil {
240 return err
241 }
242 }
243 if args.successesStored() {
244 if err := entityMap.addSuccessesEntity(args.SuccessesEntityID); err != nil {
245 return err
246 }
247 }
248 if args.iterationsStored() {
249 if err := entityMap.addIterationsEntity(args.IterationsEntityID); err != nil {
250 return err
251 }
252 }
253
254 for {
255 select {
256 case <-loopDone:
257 return nil
258 default:
259 if args.iterationsStored() {
260 if err := entityMap.incrementIterations(args.IterationsEntityID); err != nil {
261 return err
262 }
263 }
264 var loopErr error
265 for i, operation := range args.Operations {
266 if operation.Name == "loop" {
267 return fmt.Errorf("loop sub-operations should not include loop")
268 }
269 loopErr = operation.execute(ctx, loopDone)
270
271
272 if loopErr != nil {
273
274 if !args.errorsStored() && !args.failuresStored() {
275 return fmt.Errorf("error running loop operation %v : %v", i, loopErr)
276 }
277 errDoc := bson.Raw(bsoncore.NewDocumentBuilder().
278 AppendString("error", loopErr.Error()).
279 AppendDouble("time", getSecondsSinceEpoch()).
280 Build())
281 var appendErr error
282 switch {
283 case !args.errorsStored():
284 appendErr = entityMap.appendBSONArrayEntity(args.FailuresEntityID, errDoc)
285 case !args.failuresStored():
286 appendErr = entityMap.appendBSONArrayEntity(args.ErrorsEntityID, errDoc)
287
288
289
290 case strings.Contains(loopErr.Error(), "execution failed: "):
291 appendErr = entityMap.appendBSONArrayEntity(args.ErrorsEntityID, errDoc)
292
293 default:
294 appendErr = entityMap.appendBSONArrayEntity(args.FailuresEntityID, errDoc)
295 }
296 if appendErr != nil {
297 return appendErr
298 }
299
300 break
301 }
302 if args.successesStored() {
303 if err := entityMap.incrementSuccesses(args.SuccessesEntityID); err != nil {
304 return err
305 }
306 }
307 }
308 }
309 }
310 }
311
// waitForEventArguments holds the decoded arguments of a "waitForEvent" test
// runner operation.
type waitForEventArguments struct {
	// ClientID names the client entity whose observed events are counted.
	ClientID string `bson:"client"`
	// Event maps an event type name to the document describing it. For
	// serverDescriptionChangedEvent the document can narrow which events are
	// counted; for other types only the type name is used.
	Event map[string]bson.Raw `bson:"event"`
	// Count is the minimum number of matching events required for each entry.
	Count int32 `bson:"count"`
}
317
318
319
320
321
322
323
324
325
326
327 func getServerDescriptionChangedEventCount(client *clientEntity, raw bson.Raw) int32 {
328 if len(raw) == 0 {
329 return 0
330 }
331
332
333
334 if values, _ := raw.Values(); len(values) == 0 {
335 return client.getEventCount(serverDescriptionChangedEvent)
336 }
337
338 var expectedEvt serverDescriptionChangedEventInfo
339 if err := bson.Unmarshal(raw, &expectedEvt); err != nil {
340 return 0
341 }
342
343 return client.getServerDescriptionChangedEventCount(expectedEvt)
344 }
345
346
347
348 func (args waitForEventArguments) eventCompleted(client *clientEntity) bool {
349 for rawEventType, eventDoc := range args.Event {
350 eventType, ok := monitoringEventTypeFromString(rawEventType)
351 if !ok {
352 return false
353 }
354
355 switch eventType {
356 case serverDescriptionChangedEvent:
357 if getServerDescriptionChangedEventCount(client, eventDoc) < args.Count {
358 return false
359 }
360 default:
361 if client.getEventCount(eventType) < args.Count {
362 return false
363 }
364 }
365 }
366
367 return true
368 }
369
370 func waitForEvent(ctx context.Context, args waitForEventArguments) error {
371 client, err := entities(ctx).client(args.ClientID)
372 if err != nil {
373 return err
374 }
375
376 for {
377 select {
378 case <-ctx.Done():
379 return fmt.Errorf("timed out waiting for event: %v", ctx.Err())
380 default:
381 if args.eventCompleted(client) {
382 return nil
383 }
384
385 }
386
387 time.Sleep(100 * time.Millisecond)
388 }
389 }
390
391 func extractClientSession(sess mongo.Session) *session.Client {
392 return sess.(mongo.XSession).ClientSession()
393 }
394
395 func verifySessionPinnedState(ctx context.Context, sessionID string, expectedPinned bool) error {
396 sess, err := entities(ctx).session(sessionID)
397 if err != nil {
398 return err
399 }
400
401 if isPinned := extractClientSession(sess).PinnedServer != nil; expectedPinned != isPinned {
402 return fmt.Errorf("session pinned state mismatch; expected to be pinned: %v, is pinned: %v", expectedPinned, isPinned)
403 }
404 return nil
405 }
406
407 func verifyLastTwoLsidsEqual(ctx context.Context, clientID string, expectedEqual bool) error {
408 client, err := entities(ctx).client(clientID)
409 if err != nil {
410 return err
411 }
412
413 allEvents := client.startedEvents()
414 if len(allEvents) < 2 {
415 return fmt.Errorf("client has recorded fewer than two command started events")
416 }
417 lastTwoEvents := allEvents[len(allEvents)-2:]
418
419 firstID, err := lastTwoEvents[0].Command.LookupErr("lsid")
420 if err != nil {
421 return fmt.Errorf("first command has no 'lsid' field: %v", client.started[0].Command)
422 }
423 secondID, err := lastTwoEvents[1].Command.LookupErr("lsid")
424 if err != nil {
425 return fmt.Errorf("first command has no 'lsid' field: %v", client.started[1].Command)
426 }
427
428 areEqual := firstID.Equal(secondID)
429 if expectedEqual && !areEqual {
430 return fmt.Errorf("expected last two lsids to be equal, but got %s and %s", firstID, secondID)
431 }
432 if !expectedEqual && areEqual {
433 return fmt.Errorf("expected last two lsids to be different but both were %s", firstID)
434 }
435 return nil
436 }
437
438 func verifySessionDirtyState(ctx context.Context, sessionID string, expectedDirty bool) error {
439 sess, err := entities(ctx).session(sessionID)
440 if err != nil {
441 return err
442 }
443
444 if isDirty := extractClientSession(sess).Dirty; expectedDirty != isDirty {
445 return fmt.Errorf("session dirty state mismatch; expected to be dirty: %v, is dirty: %v", expectedDirty, isDirty)
446 }
447 return nil
448 }
449
450 func verifyCollectionExists(ctx context.Context, dbName, collName string, expectedExists bool) error {
451 db := mtest.GlobalClient().Database(dbName)
452 collections, err := db.ListCollectionNames(ctx, bson.M{"name": collName})
453 if err != nil {
454 return fmt.Errorf("error running ListCollectionNames: %v", err)
455 }
456
457 if exists := len(collections) == 1; expectedExists != exists {
458 ns := fmt.Sprintf("%s.%s", dbName, collName)
459 return fmt.Errorf("collection existence mismatch; expected namespace %q to exist: %v, exists: %v", ns,
460 expectedExists, exists)
461 }
462 return nil
463 }
464
465 func verifyIndexExists(ctx context.Context, dbName, collName, indexName string, expectedExists bool) error {
466 iv := mtest.GlobalClient().Database(dbName).Collection(collName).Indexes()
467 cursor, err := iv.List(ctx)
468 if err != nil {
469 return fmt.Errorf("error running IndexView.List: %v", err)
470 }
471 defer cursor.Close(ctx)
472
473 var exists bool
474 for cursor.Next(ctx) {
475 if lookupString(cursor.Current, "name") == indexName {
476 exists = true
477 break
478 }
479 }
480 if expectedExists != exists {
481 ns := fmt.Sprintf("%s.%s", dbName, collName)
482 return fmt.Errorf("index existence mismatch: expected index %q to exist in namespace %q: %v, exists: %v",
483 indexName, ns, expectedExists, exists)
484 }
485 return nil
486 }
487
View as plain text