1
2
3
4
5
6
7 package unified
8
9 import (
10 "bytes"
11 "context"
12 "errors"
13 "fmt"
14
15 "go.mongodb.org/mongo-driver/bson"
16 "go.mongodb.org/mongo-driver/bson/bsontype"
17 "go.mongodb.org/mongo-driver/bson/primitive"
18 "go.mongodb.org/mongo-driver/event"
19 )
20
// commandMonitoringEvent describes one expected command monitoring event in an expectedEvents list.
//
// verifyCommandEvents inspects the first non-nil field in declaration order (started, succeeded,
// failed) and reports an error when none of them is set. Inside the chosen event, every nil field is
// skipped during comparison.
type commandMonitoringEvent struct {
	// CommandStartedEvent is matched against the first not-yet-matched started event recorded by the
	// client. Command, when set, is compared with the actual command document via verifyValuesMatch.
	// HasServerConnectionID and HasServiceID, when set, assert whether the event carries those IDs.
	CommandStartedEvent *struct {
		Command               bson.Raw `bson:"command"`
		CommandName           *string  `bson:"commandName"`
		DatabaseName          *string  `bson:"databaseName"`
		HasServerConnectionID *bool    `bson:"hasServerConnectionId"`
		HasServiceID          *bool    `bson:"hasServiceId"`
	} `bson:"commandStartedEvent"`

	// CommandSucceededEvent is matched against the first not-yet-matched succeeded event. Reply, when
	// set, is compared with the actual reply document via verifyValuesMatch.
	CommandSucceededEvent *struct {
		CommandName           *string  `bson:"commandName"`
		DatabaseName          *string  `bson:"databaseName"`
		Reply                 bson.Raw `bson:"reply"`
		HasServerConnectionID *bool    `bson:"hasServerConnectionId"`
		HasServiceID          *bool    `bson:"hasServiceId"`
	} `bson:"commandSucceededEvent"`

	// CommandFailedEvent is matched against the first not-yet-matched failed event.
	CommandFailedEvent *struct {
		CommandName           *string `bson:"commandName"`
		DatabaseName          *string `bson:"databaseName"`
		HasServerConnectionID *bool   `bson:"hasServerConnectionId"`
		HasServiceID          *bool   `bson:"hasServiceId"`
	} `bson:"commandFailedEvent"`
}
45
// cmapEvent describes one expected connection pool (CMAP) event in an expectedEvents list.
//
// verifyCMAPEvents uses the first non-nil field in declaration order, requires the next recorded pool
// event to have the corresponding event type, and reports an error when no field is set.
type cmapEvent struct {
	// ConnectionCreatedEvent matches a pool event of type event.ConnectionCreated.
	ConnectionCreatedEvent *struct{} `bson:"connectionCreatedEvent"`

	// ConnectionReadyEvent matches a pool event of type event.ConnectionReady.
	ConnectionReadyEvent *struct{} `bson:"connectionReadyEvent"`

	// ConnectionClosedEvent matches a pool event of type event.ConnectionClosed. Reason, when set, must
	// equal the event's Reason exactly.
	ConnectionClosedEvent *struct {
		Reason *string `bson:"reason"`
	} `bson:"connectionClosedEvent"`

	// ConnectionCheckedOutEvent matches a pool event of type event.GetSucceeded.
	ConnectionCheckedOutEvent *struct{} `bson:"connectionCheckedOutEvent"`

	// ConnectionCheckOutFailedEvent matches a pool event of type event.GetFailed. Reason, when set, must
	// equal the event's Reason exactly.
	ConnectionCheckOutFailedEvent *struct {
		Reason *string `bson:"reason"`
	} `bson:"connectionCheckOutFailedEvent"`

	// ConnectionCheckedInEvent matches a pool event of type event.ConnectionReturned.
	ConnectionCheckedInEvent *struct{} `bson:"connectionCheckedInEvent"`

	// PoolClearedEvent matches a pool event of type event.PoolCleared. HasServiceID, when set, asserts
	// whether the event carries a service ID; InterruptInUseConnections, when set, must equal the
	// event's Interruption flag.
	PoolClearedEvent *struct {
		HasServiceID              *bool `bson:"hasServiceId"`
		InterruptInUseConnections *bool `bson:"interruptInUseConnections"`
	} `bson:"poolClearedEvent"`
}
68
// sdamEvent describes one expected SDAM (server discovery and monitoring) event in an expectedEvents list.
//
// verifySDAMEvents uses the first non-nil field in declaration order. Each kind of event is matched
// against the first not-yet-matched recorded event of that same kind.
type sdamEvent struct {
	// ServerDescriptionChangedEvent expectations. Each description's Type holds a server kind name,
	// which verifySDAMEvents compares with the String() form of a recorded description's Kind.
	ServerDescriptionChangedEvent *struct {
		NewDescription *struct {
			Type *string `bson:"type"`
		} `bson:"newDescription"`

		PreviousDescription *struct {
			Type *string `bson:"type"`
		} `bson:"previousDescription"`
	} `bson:"serverDescriptionChangedEvent"`

	// ServerHeartbeatStartedEvent: Awaited, when set, must equal the event's Awaited flag.
	ServerHeartbeatStartedEvent *struct {
		Awaited *bool `bson:"awaited"`
	} `bson:"serverHeartbeatStartedEvent"`

	// ServerHeartbeatSucceededEvent: Awaited, when set, must equal the event's Awaited flag.
	ServerHeartbeatSucceededEvent *struct {
		Awaited *bool `bson:"awaited"`
	} `bson:"serverHeartbeatSucceededEvent"`

	// ServerHeartbeatFailedEvent: Awaited, when set, must equal the event's Awaited flag.
	ServerHeartbeatFailedEvent *struct {
		Awaited *bool `bson:"awaited"`
	} `bson:"serverHeartbeatFailedEvent"`

	// TopologyDescriptionChangedEvent has no fields; only the presence of a recorded event is checked.
	TopologyDescriptionChangedEvent *struct{} `bson:"topologyDescriptionChangedEvent"`
}
94
// expectedEvents lists the events a single client entity is expected to have published.
//
// It is populated by UnmarshalBSON, which fills only one of CommandEvents, CMAPEvents, or SDAMEvents,
// selected by the document's "eventType" key. ClientID names the client entity whose recorded events
// are checked. IgnoreExtraEvents, when true, allows the client to have published more events than are
// listed; a nil value is treated as false by the verifiers.
type expectedEvents struct {
	ClientID          string `bson:"client"`
	CommandEvents     []commandMonitoringEvent
	CMAPEvents        []cmapEvent
	SDAMEvents        []sdamEvent
	IgnoreExtraEvents *bool
}
102
103 var _ bson.Unmarshaler = (*expectedEvents)(nil)
104
105 func (e *expectedEvents) UnmarshalBSON(data []byte) error {
106
107
108
109 var temp struct {
110 ClientID string `bson:"client"`
111 EventType string `bson:"eventType"`
112 Events bson.RawValue `bson:"events"`
113 IgnoreExtraEvents *bool `bson:"ignoreExtraEvents"`
114 Extra map[string]interface{} `bson:",inline"`
115 }
116 if err := bson.Unmarshal(data, &temp); err != nil {
117 return fmt.Errorf("error unmarshalling to temporary expectedEvents object: %w", err)
118 }
119 if len(temp.Extra) > 0 {
120 return fmt.Errorf("unrecognized fields for expectedEvents: %v", temp.Extra)
121 }
122
123 e.ClientID = temp.ClientID
124 if temp.Events.Type != bsontype.Array {
125 return fmt.Errorf("expected 'events' to be an array but got a %q", temp.Events.Type)
126 }
127
128 var target interface{}
129 switch temp.EventType {
130 case "command", "":
131 target = &e.CommandEvents
132 case "cmap":
133 target = &e.CMAPEvents
134 case "sdam":
135 target = &e.SDAMEvents
136 default:
137 return fmt.Errorf("unrecognized 'eventType' value for expectedEvents: %q", temp.EventType)
138 }
139
140 if err := temp.Events.Unmarshal(target); err != nil {
141 return fmt.Errorf("error unmarshalling events array: %w", err)
142 }
143
144 if temp.IgnoreExtraEvents != nil {
145 e.IgnoreExtraEvents = temp.IgnoreExtraEvents
146 }
147 return nil
148 }
149
150 func verifyEvents(ctx context.Context, expectedEvents *expectedEvents) error {
151 client, err := entities(ctx).client(expectedEvents.ClientID)
152 if err != nil {
153 return err
154 }
155
156 switch {
157 case expectedEvents.CommandEvents != nil:
158 return verifyCommandEvents(ctx, client, expectedEvents)
159 case expectedEvents.CMAPEvents != nil:
160 return verifyCMAPEvents(client, expectedEvents)
161 case expectedEvents.SDAMEvents != nil:
162 return verifySDAMEvents(client, expectedEvents)
163 }
164 return nil
165 }
166
167 func verifyCommandEvents(ctx context.Context, client *clientEntity, expectedEvents *expectedEvents) error {
168 started := client.startedEvents()
169 succeeded := client.succeededEvents()
170 failed := client.failedEvents()
171
172
173 if len(expectedEvents.CommandEvents) == 0 && (len(started)+len(succeeded)+len(failed) != 0) {
174 return fmt.Errorf("expected no events to be sent but got %s", stringifyEventsForClient(client))
175 }
176
177 for idx, evt := range expectedEvents.CommandEvents {
178 switch {
179 case evt.CommandStartedEvent != nil:
180 if len(started) == 0 {
181 return newEventVerificationError(idx, client, "no CommandStartedEvent published")
182 }
183
184 actual := started[0]
185 started = started[1:]
186
187 expected := evt.CommandStartedEvent
188 if expected.CommandName != nil && *expected.CommandName != actual.CommandName {
189 return newEventVerificationError(idx, client, "expected command name %q, got %q", *expected.CommandName,
190 actual.CommandName)
191 }
192 if expected.DatabaseName != nil && *expected.DatabaseName != actual.DatabaseName {
193 return newEventVerificationError(idx, client, "expected database name %q, got %q", *expected.DatabaseName,
194 actual.DatabaseName)
195 }
196 if expected.Command != nil {
197 expectedDoc := documentToRawValue(expected.Command)
198 actualDoc := documentToRawValue(actual.Command)
199
200
201
202
203
204 if len(actual.Command) == 0 {
205 emptyDoc := []byte{5, 0, 0, 0, 0}
206 actualDoc = bson.RawValue{Type: bsontype.EmbeddedDocument, Value: emptyDoc}
207 }
208
209 if err := verifyValuesMatch(ctx, expectedDoc, actualDoc, true); err != nil {
210 return newEventVerificationError(idx, client, "error comparing command documents: %v", err)
211 }
212 }
213 if expected.HasServiceID != nil {
214 if err := verifyServiceID(*expected.HasServiceID, actual.ServiceID); err != nil {
215 return newEventVerificationError(idx, client, "error verifying serviceID: %v", err)
216 }
217 }
218 if expected.HasServerConnectionID != nil {
219 if err := verifyServerConnectionID(*expected.HasServerConnectionID, actual.ServerConnectionID64); err != nil {
220 return newEventVerificationError(idx, client, "error verifying serverConnectionID: %v", err)
221 }
222 }
223 case evt.CommandSucceededEvent != nil:
224 if len(succeeded) == 0 {
225 return newEventVerificationError(idx, client, "no CommandSucceededEvent published")
226 }
227
228 actual := succeeded[0]
229 succeeded = succeeded[1:]
230
231 expected := evt.CommandSucceededEvent
232 if expected.CommandName != nil && *expected.CommandName != actual.CommandName {
233 return newEventVerificationError(idx, client, "expected command name %q, got %q", *expected.CommandName,
234 actual.CommandName)
235 }
236 if expected.DatabaseName != nil && *expected.DatabaseName != actual.DatabaseName {
237 return newEventVerificationError(idx, client, "expected database name %q, got %q", *expected.DatabaseName,
238 actual.DatabaseName)
239 }
240 if expected.Reply != nil {
241 expectedDoc := documentToRawValue(expected.Reply)
242 actualDoc := documentToRawValue(actual.Reply)
243
244
245
246
247
248 if len(actual.Reply) == 0 {
249 emptyDoc := []byte{5, 0, 0, 0, 0}
250 actualDoc = bson.RawValue{Type: bsontype.EmbeddedDocument, Value: emptyDoc}
251 }
252
253 if err := verifyValuesMatch(ctx, expectedDoc, actualDoc, true); err != nil {
254 return newEventVerificationError(idx, client, "error comparing reply documents: %v", err)
255 }
256 }
257 if expected.HasServiceID != nil {
258 if err := verifyServiceID(*expected.HasServiceID, actual.ServiceID); err != nil {
259 return newEventVerificationError(idx, client, "error verifying serviceID: %v", err)
260 }
261 }
262 if expected.HasServerConnectionID != nil {
263 if err := verifyServerConnectionID(*expected.HasServerConnectionID, actual.ServerConnectionID64); err != nil {
264 return newEventVerificationError(idx, client, "error verifying serverConnectionID: %v", err)
265 }
266 }
267 case evt.CommandFailedEvent != nil:
268 if len(failed) == 0 {
269 return newEventVerificationError(idx, client, "no CommandFailedEvent published")
270 }
271
272 actual := failed[0]
273 failed = failed[1:]
274
275 expected := evt.CommandFailedEvent
276 if expected.CommandName != nil && *expected.CommandName != actual.CommandName {
277 return newEventVerificationError(idx, client, "expected command name %q, got %q", *expected.CommandName,
278 actual.CommandName)
279 }
280 if expected.DatabaseName != nil && *expected.DatabaseName != actual.DatabaseName {
281 return newEventVerificationError(idx, client, "expected database name %q, got %q", *expected.DatabaseName,
282 actual.DatabaseName)
283 }
284 if expected.HasServiceID != nil {
285 if err := verifyServiceID(*expected.HasServiceID, actual.ServiceID); err != nil {
286 return newEventVerificationError(idx, client, "error verifying serviceID: %v", err)
287 }
288 }
289 if expected.HasServerConnectionID != nil {
290 if err := verifyServerConnectionID(*expected.HasServerConnectionID, actual.ServerConnectionID64); err != nil {
291 return newEventVerificationError(idx, client, "error verifying serverConnectionID: %v", err)
292 }
293 }
294 default:
295 return newEventVerificationError(idx, client, "no expected event set on commandMonitoringEvent instance")
296 }
297 }
298
299
300 ignoreExtraEvents := expectedEvents.IgnoreExtraEvents != nil && *expectedEvents.IgnoreExtraEvents
301 if !ignoreExtraEvents && (len(started) > 0 || len(succeeded) > 0 || len(failed) > 0) {
302 return fmt.Errorf("extra events published; all events for client: %s", stringifyEventsForClient(client))
303 }
304 return nil
305 }
306
307 func verifyCMAPEvents(client *clientEntity, expectedEvents *expectedEvents) error {
308 pooled := client.poolEvents()
309 if len(expectedEvents.CMAPEvents) == 0 && len(pooled) != 0 {
310 return fmt.Errorf("expected no cmap events to be sent but got %s", stringifyEventsForClient(client))
311 }
312
313 for idx, evt := range expectedEvents.CMAPEvents {
314 var err error
315
316 switch {
317 case evt.ConnectionCreatedEvent != nil:
318 if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionCreated); err != nil {
319 return newEventVerificationError(idx, client, err.Error())
320 }
321 case evt.ConnectionReadyEvent != nil:
322 if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionReady); err != nil {
323 return newEventVerificationError(idx, client, err.Error())
324 }
325 case evt.ConnectionClosedEvent != nil:
326 var actual *event.PoolEvent
327 if actual, pooled, err = getNextPoolEvent(pooled, event.ConnectionClosed); err != nil {
328 return newEventVerificationError(idx, client, err.Error())
329 }
330
331 if expectedReason := evt.ConnectionClosedEvent.Reason; expectedReason != nil {
332 if *expectedReason != actual.Reason {
333 return newEventVerificationError(idx, client, "expected reason %q, got %q", *expectedReason, actual.Reason)
334 }
335 }
336 case evt.ConnectionCheckedOutEvent != nil:
337 if _, pooled, err = getNextPoolEvent(pooled, event.GetSucceeded); err != nil {
338 return newEventVerificationError(idx, client, err.Error())
339 }
340 case evt.ConnectionCheckOutFailedEvent != nil:
341 var actual *event.PoolEvent
342 if actual, pooled, err = getNextPoolEvent(pooled, event.GetFailed); err != nil {
343 return newEventVerificationError(idx, client, err.Error())
344 }
345
346 if expectedReason := evt.ConnectionCheckOutFailedEvent.Reason; expectedReason != nil {
347 if *expectedReason != actual.Reason {
348 return newEventVerificationError(idx, client, "expected reason %q, got %q", *expectedReason, actual.Reason)
349 }
350 }
351 case evt.ConnectionCheckedInEvent != nil:
352 if _, pooled, err = getNextPoolEvent(pooled, event.ConnectionReturned); err != nil {
353 return newEventVerificationError(idx, client, err.Error())
354 }
355 case evt.PoolClearedEvent != nil:
356 var actual *event.PoolEvent
357 if actual, pooled, err = getNextPoolEvent(pooled, event.PoolCleared); err != nil {
358 return newEventVerificationError(idx, client, err.Error())
359 }
360 if expectServiceID := evt.PoolClearedEvent.HasServiceID; expectServiceID != nil {
361 if err := verifyServiceID(*expectServiceID, actual.ServiceID); err != nil {
362 return newEventVerificationError(idx, client, "error verifying serviceID: %v", err)
363 }
364 }
365 if expectInterruption := evt.PoolClearedEvent.InterruptInUseConnections; expectInterruption != nil && *expectInterruption != actual.Interruption {
366 return newEventVerificationError(idx, client, "expected interruptInUseConnections %v, got %v",
367 expectInterruption, actual.Interruption)
368 }
369 default:
370 return newEventVerificationError(idx, client, "no expected event set on cmapEvent instance")
371 }
372 }
373
374
375 ignoreExtraEvents := expectedEvents.IgnoreExtraEvents != nil && *expectedEvents.IgnoreExtraEvents
376 if !ignoreExtraEvents && len(pooled) > 0 {
377 return fmt.Errorf("extra events published; all events for client: %s", stringifyEventsForClient(client))
378 }
379 return nil
380 }
381
382 func getNextPoolEvent(events []*event.PoolEvent, expectedType string) (*event.PoolEvent, []*event.PoolEvent, error) {
383 if len(events) == 0 {
384 return nil, nil, fmt.Errorf("no %q event published", expectedType)
385 }
386
387 evt := events[0]
388 if evt.Type != expectedType {
389 return nil, nil, fmt.Errorf("expected pool event of type %q, got %q", expectedType, evt.Type)
390 }
391 return evt, events[1:], nil
392 }
393
394 func verifyServiceID(expectServiceID bool, serviceID *primitive.ObjectID) error {
395 if eventHasID := serviceID != nil; expectServiceID != eventHasID {
396 return fmt.Errorf("expected event to have server ID: %v, event has server ID %v", expectServiceID, serviceID)
397 }
398 return nil
399 }
400
// verifyServerConnectionID checks that the presence of a server connection ID on an event matches the
// expectation. When an ID is expected, it must also be strictly positive. It returns nil on success.
func verifyServerConnectionID(expectedHasSCID bool, scid *int64) error {
	hasSCID := scid != nil

	switch {
	case expectedHasSCID && !hasSCID:
		return fmt.Errorf("expected event to have server connection ID, event has none")
	case !expectedHasSCID && hasSCID:
		return fmt.Errorf("expected event to have no server connection ID, got %d", *scid)
	case expectedHasSCID && *scid <= 0:
		// Reached only when presence matched, so scid is non-nil here.
		return fmt.Errorf("expected event to have a positive server connection ID, got %d", *scid)
	}
	return nil
}
413
414 func newEventVerificationError(idx int, client *clientEntity, msg string, args ...interface{}) error {
415 fullMsg := fmt.Sprintf(msg, args...)
416 return fmt.Errorf("event comparison failed at index %d: %s; all events found for client: %s", idx, fullMsg,
417 stringifyEventsForClient(client))
418 }
419
420 func stringifyEventsForClient(client *clientEntity) string {
421 str := bytes.NewBuffer(nil)
422
423 str.WriteString("\n\nStarted Events\n\n")
424 for _, evt := range client.startedEvents() {
425 str.WriteString(fmt.Sprintf("[%s] %s\n", evt.ConnectionID, evt.Command))
426 }
427
428 str.WriteString("\nSucceeded Events\n\n")
429 for _, evt := range client.succeededEvents() {
430 str.WriteString(fmt.Sprintf("[%s] CommandName: %s, Reply: %s\n", evt.ConnectionID, evt.CommandName, evt.Reply))
431 }
432
433 str.WriteString("\nFailed Events\n\n")
434 for _, evt := range client.failedEvents() {
435 str.WriteString(fmt.Sprintf("[%s] CommandName: %s, Failure: %s\n", evt.ConnectionID, evt.CommandName, evt.Failure))
436 }
437
438 str.WriteString("\nPool Events\n\n")
439 for _, evt := range client.poolEvents() {
440 str.WriteString(fmt.Sprintf("[%s] Event Type: %q\n", evt.Address, evt.Type))
441 }
442
443 return str.String()
444 }
445
446 func getNextServerDescriptionChangedEvent(
447 events []*event.ServerDescriptionChangedEvent,
448 ) (*event.ServerDescriptionChangedEvent, []*event.ServerDescriptionChangedEvent, error) {
449 if len(events) == 0 {
450 return nil, nil, errors.New("no server changed event published")
451 }
452
453 return events[0], events[1:], nil
454 }
455
456 func getNextServerHeartbeatStartedEvent(
457 events []*event.ServerHeartbeatStartedEvent,
458 ) (*event.ServerHeartbeatStartedEvent, []*event.ServerHeartbeatStartedEvent, error) {
459 if len(events) == 0 {
460 return nil, nil, errors.New("no heartbeat started event published")
461 }
462
463 return events[0], events[1:], nil
464 }
465
466 func getNextServerHeartbeatSucceededEvent(
467 events []*event.ServerHeartbeatSucceededEvent,
468 ) (*event.ServerHeartbeatSucceededEvent, []*event.ServerHeartbeatSucceededEvent, error) {
469 if len(events) == 0 {
470 return nil, nil, errors.New("no heartbeat succeeded event published")
471 }
472
473 return events[0], events[:1], nil
474 }
475
476 func getNextServerHeartbeatFailedEvent(
477 events []*event.ServerHeartbeatFailedEvent,
478 ) (*event.ServerHeartbeatFailedEvent, []*event.ServerHeartbeatFailedEvent, error) {
479 if len(events) == 0 {
480 return nil, nil, errors.New("no heartbeat failed event published")
481 }
482
483 return events[0], events[:1], nil
484 }
485
486 func getNextTopologyDescriptionChangedEvent(
487 events []*event.TopologyDescriptionChangedEvent,
488 ) (*event.TopologyDescriptionChangedEvent, []*event.TopologyDescriptionChangedEvent, error) {
489 if len(events) == 0 {
490 return nil, nil, errors.New("no topology description changed event published")
491 }
492
493 return events[0], events[:1], nil
494 }
495
496 func verifySDAMEvents(client *clientEntity, expectedEvents *expectedEvents) error {
497 var (
498 changed = client.serverDescriptionChanged
499 started = client.serverHeartbeatStartedEvent
500 succeeded = client.serverHeartbeatSucceeded
501 failed = client.serverHeartbeatFailedEvent
502 tchanged = client.topologyDescriptionChanged
503 )
504
505 vol := func() int { return len(changed) + len(started) + len(succeeded) + len(failed) + len(tchanged) }
506
507 if len(expectedEvents.SDAMEvents) == 0 && vol() != 0 {
508 return fmt.Errorf("expected no sdam events to be sent but got %s", stringifyEventsForClient(client))
509 }
510
511 for idx, evt := range expectedEvents.SDAMEvents {
512 var err error
513
514 switch {
515 case evt.ServerDescriptionChangedEvent != nil:
516 var got *event.ServerDescriptionChangedEvent
517 if got, changed, err = getNextServerDescriptionChangedEvent(changed); err != nil {
518 return newEventVerificationError(idx, client, err.Error())
519 }
520
521 prevDesc := evt.ServerDescriptionChangedEvent.NewDescription
522
523 var wantPrevDesc string
524 if prevDesc != nil && prevDesc.Type != nil {
525 wantPrevDesc = *prevDesc.Type
526 }
527
528 gotPrevDesc := got.PreviousDescription.Kind.String()
529 if gotPrevDesc != wantPrevDesc {
530 return newEventVerificationError(idx, client,
531 "expected previous server description %q, got %q", wantPrevDesc, gotPrevDesc)
532 }
533
534 newDesc := evt.ServerDescriptionChangedEvent.PreviousDescription
535
536 var wantNewDesc string
537 if newDesc != nil && newDesc.Type != nil {
538 wantNewDesc = *newDesc.Type
539 }
540
541 gotNewDesc := got.NewDescription.Kind.String()
542 if gotNewDesc != wantNewDesc {
543 return newEventVerificationError(idx, client,
544 "expected new server description %q, got %q", wantNewDesc, gotNewDesc)
545 }
546 case evt.ServerHeartbeatStartedEvent != nil:
547 var got *event.ServerHeartbeatStartedEvent
548 if got, started, err = getNextServerHeartbeatStartedEvent(started); err != nil {
549 return newEventVerificationError(idx, client, err.Error())
550 }
551
552 if want := evt.ServerHeartbeatStartedEvent.Awaited; want != nil && *want != got.Awaited {
553 return newEventVerificationError(idx, client, "want awaited %v, got %v", *want, got.Awaited)
554 }
555 case evt.ServerHeartbeatSucceededEvent != nil:
556 var got *event.ServerHeartbeatSucceededEvent
557 if got, succeeded, err = getNextServerHeartbeatSucceededEvent(succeeded); err != nil {
558 return newEventVerificationError(idx, client, err.Error())
559 }
560
561 if want := evt.ServerHeartbeatSucceededEvent.Awaited; want != nil && *want != got.Awaited {
562 return newEventVerificationError(idx, client, "want awaited %v, got %v", *want, got.Awaited)
563 }
564 case evt.ServerHeartbeatFailedEvent != nil:
565 var got *event.ServerHeartbeatFailedEvent
566 if got, failed, err = getNextServerHeartbeatFailedEvent(failed); err != nil {
567 return newEventVerificationError(idx, client, err.Error())
568 }
569
570 if want := evt.ServerHeartbeatFailedEvent.Awaited; want != nil && *want != got.Awaited {
571 return newEventVerificationError(idx, client, "want awaited %v, got %v", *want, got.Awaited)
572 }
573 case evt.TopologyDescriptionChangedEvent != nil:
574 if _, tchanged, err = getNextTopologyDescriptionChangedEvent(tchanged); err != nil {
575 return newEventVerificationError(idx, client, err.Error())
576 }
577 }
578 }
579
580
581 ignoreExtraEvents := expectedEvents.IgnoreExtraEvents != nil && *expectedEvents.IgnoreExtraEvents
582 if !ignoreExtraEvents && vol() > 0 {
583 return fmt.Errorf("extra sdam events published; all events for client: %s", stringifyEventsForClient(client))
584 }
585 return nil
586 }
587
View as plain text