1
2
3
4
5
6
7 package integration
8
9 import (
10 "bytes"
11 "encoding/hex"
12 "errors"
13 "fmt"
14 "strings"
15
16 "go.mongodb.org/mongo-driver/bson"
17 "go.mongodb.org/mongo-driver/bson/bsontype"
18 "go.mongodb.org/mongo-driver/event"
19 "go.mongodb.org/mongo-driver/internal/assert"
20 "go.mongodb.org/mongo-driver/mongo/integration/mtest"
21 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
22 )
23
24
25
26 func numberFromValue(mt *mtest.T, val bson.RawValue) int64 {
27 mt.Helper()
28
29 switch val.Type {
30 case bson.TypeInt32:
31 return int64(val.Int32())
32 case bson.TypeInt64:
33 return val.Int64()
34 case bson.TypeDouble:
35 return int64(val.Double())
36 default:
37 mt.Fatalf("unexpected type for number: %v", val.Type)
38 }
39
40 return 0
41 }
42
// compareNumberValues compares two BSON numbers for the field named key after
// normalizing both to int64 (see numberFromValue). int32, int64 and double
// values therefore match when they are numerically equal after truncation.
//
// An expected value of 42 is a placeholder: it matches any actual value that is
// not null. If actual is not a number, a type-mismatch error is returned.
func compareNumberValues(mt *mtest.T, key string, expected, actual bson.RawValue) error {
	mt.Helper()

	eInt := numberFromValue(mt, expected)
	if eInt == 42 {
		if actual.Type == bson.TypeNull {
			return fmt.Errorf("expected non-null value for key %s, got null", key)
		}
		return nil
	}

	// Report a non-numeric actual value as an ordinary mismatch that names the
	// key, instead of letting numberFromValue abort the whole test.
	switch actual.Type {
	case bson.TypeInt32, bson.TypeInt64, bson.TypeDouble:
	default:
		return fmt.Errorf("type mismatch for key %s; expected %s, got %s", key, expected.Type, actual.Type)
	}

	aInt := numberFromValue(mt, actual)
	if eInt != aInt {
		return fmt.Errorf("value mismatch for key %s; expected %s, got %s", key, expected, actual)
	}
	return nil
}
58
59
60
61
// compareValues checks that actual matches expected for the field named key.
// It returns nil on a match and a descriptive error otherwise.
//
// Matching rules:
//   - numeric expected values are compared with compareNumberValues;
//   - the expected string "42" matches any non-null actual value;
//   - an expected document containing "$$type" only checks the BSON type of
//     actual (see checkValueType);
//   - other documents and arrays must have the same type in actual and are
//     compared recursively with compareDocsHelper;
//   - all other values must have the same BSON type and identical raw bytes.
func compareValues(mt *mtest.T, key string, expected, actual bson.RawValue) error {
	mt.Helper()

	typeMismatch := func() error {
		return fmt.Errorf("type mismatch for key %s; expected %s, got %s", key, expected.Type, actual.Type)
	}

	switch expected.Type {
	case bson.TypeInt32, bson.TypeInt64, bson.TypeDouble:
		return compareNumberValues(mt, key, expected, actual)
	case bson.TypeString:
		// "42" is a placeholder that matches any non-null value; other strings
		// fall through to the exact comparison below.
		if expected.StringValue() == "42" {
			if actual.Type == bson.TypeNull {
				return fmt.Errorf("expected non-null value for key %s, got null", key)
			}
			return nil
		}
	case bson.TypeEmbeddedDocument:
		e := expected.Document()
		if typeVal, err := e.LookupErr("$$type"); err == nil {
			// {"$$type": "<alias>"} only constrains the BSON type of actual.
			return checkValueType(mt, key, actual.Type, typeVal.StringValue())
		}
		// RawValue.Document panics on non-document values, so verify the type
		// first and report a mismatch instead.
		if actual.Type != bson.TypeEmbeddedDocument {
			return typeMismatch()
		}
		return compareDocsHelper(mt, e, actual.Document(), key)
	case bson.TypeArray:
		// Same reasoning as above: RawValue.Array panics on non-array values.
		if actual.Type != bson.TypeArray {
			return typeMismatch()
		}
		return compareDocsHelper(mt, expected.Array(), actual.Array(), key)
	}

	if expected.Type != actual.Type {
		return typeMismatch()
	}
	if !bytes.Equal(expected.Value, actual.Value) {
		// Print the RawValues (human-readable) rather than their raw byte
		// slices, and keep the hex dumps for byte-level debugging.
		return fmt.Errorf(
			"value mismatch for key %s; expected %s (hex=%s), got %s (hex=%s)",
			key,
			expected,
			hex.EncodeToString(expected.Value),
			actual,
			hex.EncodeToString(actual.Value))
	}
	return nil
}
110
111
112 func checkValueType(mt *mtest.T, key string, actual bsontype.Type, typeStr string) error {
113 mt.Helper()
114
115 var expected bsontype.Type
116 switch typeStr {
117 case "double":
118 expected = bsontype.Double
119 case "string":
120 expected = bsontype.String
121 case "object":
122 expected = bsontype.EmbeddedDocument
123 case "array":
124 expected = bsontype.Array
125 case "binData":
126 expected = bsontype.Binary
127 case "undefined":
128 expected = bsontype.Undefined
129 case "objectId":
130 expected = bsontype.ObjectID
131 case "boolean":
132 expected = bsontype.Boolean
133 case "date":
134 expected = bsontype.DateTime
135 case "null":
136 expected = bsontype.Null
137 case "regex":
138 expected = bsontype.Regex
139 case "dbPointer":
140 expected = bsontype.DBPointer
141 case "javascript":
142 expected = bsontype.JavaScript
143 case "symbol":
144 expected = bsontype.Symbol
145 case "javascriptWithScope":
146 expected = bsontype.CodeWithScope
147 case "int":
148 expected = bsontype.Int32
149 case "timestamp":
150 expected = bsontype.Timestamp
151 case "long":
152 expected = bsontype.Int64
153 case "decimal":
154 expected = bsontype.Decimal128
155 case "minKey":
156 expected = bsontype.MinKey
157 case "maxKey":
158 expected = bsontype.MaxKey
159 default:
160 mt.Fatalf("unrecognized type string: %v", typeStr)
161 }
162
163 if expected != actual {
164 return fmt.Errorf("BSON type mismatch for key %s; expected %s, got %s", key, expected, actual)
165 }
166 return nil
167 }
168
169
// compareDocsHelper checks that every element of expected is matched by the
// element with the same key in actual. prefix is the dotted path of the
// enclosing document; it is used only to build key names for error messages.
//
// An expected null means the key must be absent from actual. Keys that appear
// only in actual are ignored. Arrays are also passed here: BSON encodes them as
// documents keyed "0", "1", ..., so they are compared element by element.
func compareDocsHelper(mt *mtest.T, expected, actual bson.Raw, prefix string) error {
	mt.Helper()

	eElems, err := expected.Elements()
	if err != nil {
		// Return instead of asserting: a non-fatal assertion followed by an
		// empty loop would make this comparison report success.
		return fmt.Errorf("error getting expected elements: %w", err)
	}

	for _, e := range eElems {
		eKey := e.Key()
		fullKeyName := eKey
		if prefix != "" {
			fullKeyName = prefix + "." + eKey
		}

		aVal, err := actual.LookupErr(eKey)
		if e.Value().Type == bson.TypeNull {
			// A null expectation means the key must be omitted.
			if errors.Is(err, bsoncore.ErrElementNotFound) {
				continue
			}
			if err != nil {
				return fmt.Errorf("expected key %q to be omitted but got error: %w", fullKeyName, err)
			}
			return fmt.Errorf("expected key %q to be omitted but got %q", fullKeyName, aVal)
		}
		if err != nil {
			return fmt.Errorf("key %s not found in result", fullKeyName)
		}

		if err := compareValues(mt, fullKeyName, e.Value(), aVal); err != nil {
			return err
		}
	}
	return nil
}
204
205 func compareDocs(mt *mtest.T, expected, actual bson.Raw) error {
206 mt.Helper()
207 return compareDocsHelper(mt, expected, actual, "")
208 }
209
// checkExpectations compares the command monitoring events recorded by mt with
// expectations, in order. id0 and id1 are the lsid documents that the session
// names "session0" and "session1" in expected commands resolve to.
//
// A nil expectations pointer disables the check entirely. configureFailPoint
// events are filtered out before comparison. An empty expectations list
// asserts that no commands were started. The one exception is the bulk write
// test named below, which is expected to start exactly one command.
func checkExpectations(mt *mtest.T, expectations *[]*expectation, id0, id1 bson.Raw) {
	mt.Helper()

	if expectations == nil {
		return
	}

	// isTracked reports whether an event for commandName should be compared.
	isTracked := func(commandName string) bool {
		return commandName != "configureFailPoint"
	}
	mt.FilterStartedEvents(func(evt *event.CommandStartedEvent) bool {
		return isTracked(evt.CommandName)
	})
	mt.FilterSucceededEvents(func(evt *event.CommandSucceededEvent) bool {
		return isTracked(evt.CommandName)
	})
	mt.FilterFailedEvents(func(evt *event.CommandFailedEvent) bool {
		return isTracked(evt.CommandName)
	})

	all := *expectations
	if len(all) == 0 {
		const bulkWriteTestName = "BulkWrite_on_server_that_doesn't_support_arrayFilters_with_arrayFilters_on_second_op"

		wantEvents := 0
		if strings.HasSuffix(mt.Name(), bulkWriteTestName) {
			wantEvents = 1
		}
		gotEvents := len(mt.GetAllStartedEvents())
		assert.Equal(mt, wantEvents, gotEvents, "expected %d events to be sent, but got %d events",
			wantEvents, gotEvents)
		return
	}

	for i, exp := range all {
		// Each present event kind is compared in turn; the last comparison's
		// result is the one asserted.
		var err error
		if exp.CommandStartedEvent != nil {
			err = compareStartedEvent(mt, exp, id0, id1)
		}
		if exp.CommandSucceededEvent != nil {
			err = compareSucceededEvent(mt, exp)
		}
		if exp.CommandFailedEvent != nil {
			err = compareFailedEvent(mt, exp)
		}
		assert.Nil(mt, err, "expectation comparison error at index %v: %s", i, err)
	}
}
272
273
// newMatchError builds an error whose message is format/args followed by the
// expected and actual documents rendered as extended JSON. A marshalling failure
// is recorded as a (non-fatal) assertion failure on mt.
func newMatchError(mt *mtest.T, expected bson.Raw, actual bson.Raw, format string, args ...interface{}) error {
	mt.Helper()

	render := func(doc bson.Raw) string {
		mt.Helper()
		out, err := bson.MarshalExtJSON(doc, true, false)
		assert.Nil(mt, err, "error in MarshalExtJSON: %v", err)
		return string(out)
	}

	summary := fmt.Sprintf(format, args...)
	wantJSON := render(expected)
	gotJSON := render(actual)
	return fmt.Errorf("%s\nExpected %s\nGot: %s", summary, wantJSON, gotJSON)
}
283
// compareStartedEvent retrieves a CommandStartedEvent from mt and checks it
// against expectation.CommandStartedEvent.
//
// The command name and database name are compared only when the expectation
// sets them. Every key of the expected command must match the actual command,
// with these special cases:
//   - an expected null means the key must be absent;
//   - "batchSize" must be present, but its value is not compared;
//   - "lsid" holds a session name ("session0" or "session1"). The name is
//     resolved to id0 or id1 and compared byte-for-byte with the actual lsid
//     document.
//
// Keys present only in the actual command are ignored.
func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bson.Raw) error {
	mt.Helper()

	expected := expectation.CommandStartedEvent

	if len(expected.Extra) > 0 {
		return fmt.Errorf("unrecognized fields for CommandStartedEvent: %v", expected.Extra)
	}

	evt := mt.GetStartedEvent()
	if evt == nil {
		return errors.New("expected CommandStartedEvent, got nil")
	}

	if expected.CommandName != "" && expected.CommandName != evt.CommandName {
		return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName)
	}
	if expected.DatabaseName != "" && expected.DatabaseName != evt.DatabaseName {
		return fmt.Errorf("database name mismatch; expected %s, got %s", expected.DatabaseName, evt.DatabaseName)
	}

	eElems, err := expected.Command.Elements()
	if err != nil {
		return fmt.Errorf("error getting expected command elements: %s", err)
	}

	for _, elem := range eElems {
		key := elem.Key()
		val := elem.Value()

		actualVal, err := evt.Command.LookupErr(key)

		// A null expected value means the key must be absent from the command.
		if val.Type == bson.TypeNull {
			if errors.Is(err, bsoncore.ErrElementNotFound) {
				continue
			}
			if err != nil {
				return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got error: %v", key, err)
			}
			return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got %q", key, actualVal)
		}
		if err != nil {
			// Stop here: continuing with the zero RawValue would make the checks
			// below panic (e.g. Document() for lsid) or report a misleading
			// mismatch.
			return newMatchError(mt, expected.Command, evt.Command, "expected command to contain key %q: %v", key, err)
		}

		switch key {
		case "batchSize":
			// Only the presence of batchSize is checked; its value is not compared.
			// NOTE(review): the reason for skipping the value is not visible here.
		case "lsid":
			sessName := val.StringValue()
			var expectedID bson.Raw
			switch sessName {
			case "session0":
				expectedID = id0
			case "session1":
				expectedID = id1
			default:
				return newMatchError(mt, expected.Command, evt.Command, "unrecognized session identifier in command document: %s", sessName)
			}

			// RawValue.Document panics on non-document values.
			if actualVal.Type != bson.TypeEmbeddedDocument {
				return newMatchError(mt, expected.Command, evt.Command, "expected lsid to be a document, got %s", actualVal.Type)
			}
			actualID := actualVal.Document()
			if !bytes.Equal(expectedID, actualID) {
				return newMatchError(mt, expected.Command, evt.Command, "session ID mismatch for session %s; expected %s, got %s", sessName, expectedID,
					actualID)
			}
		default:
			if err := compareValues(mt, key, val, actualVal); err != nil {
				return newMatchError(mt, expected.Command, evt.Command, "%s", err)
			}
		}
	}
	return nil
}
364
// compareWriteErrors compares the expected and actual "writeErrors" arrays entry
// by entry. For every expected entry, the actual entry at the same position must
// satisfy three rules:
//   - "index" must match exactly;
//   - "code" must match unless the expected code is the placeholder 42;
//   - "errmsg" must match exactly. If the expected message is empty, any
//     non-empty actual message is accepted.
//
// Actual entries beyond the expected ones are ignored. An actual array shorter
// than the expected one is reported as an error.
func compareWriteErrors(mt *mtest.T, expected, actual bson.Raw) error {
	mt.Helper()

	expectedErrors, err := expected.Values()
	if err != nil {
		return fmt.Errorf("error getting expected write errors: %w", err)
	}
	actualErrors, err := actual.Values()
	if err != nil {
		return fmt.Errorf("error getting actual write errors: %w", err)
	}
	// Guard the actualErrors[i] index below.
	if len(actualErrors) < len(expectedErrors) {
		return fmt.Errorf("write error count mismatch; expected %d, got %d", len(expectedErrors), len(actualErrors))
	}

	for i, expectedErrVal := range expectedErrors {
		expectedErr := expectedErrVal.Document()
		actualErr := actualErrors[i].Document()

		eIdx := expectedErr.Lookup("index").Int32()
		aIdx := actualErr.Lookup("index").Int32()
		if eIdx != aIdx {
			return fmt.Errorf("write error index mismatch at index %d; expected %d, got %d", i, eIdx, aIdx)
		}

		eCode := expectedErr.Lookup("code").Int32()
		aCode := actualErr.Lookup("code").Int32()
		if eCode != 42 && eCode != aCode {
			return fmt.Errorf("write error code mismatch at index %d; expected %d, got %d", i, eCode, aCode)
		}

		eMsg := expectedErr.Lookup("errmsg").StringValue()
		aMsg := actualErr.Lookup("errmsg").StringValue()
		switch {
		case eMsg == "" && aMsg == "":
			return fmt.Errorf("write error message mismatch at index %d; expected non-empty message, got empty", i)
		case eMsg != "" && eMsg != aMsg:
			return fmt.Errorf("write error message mismatch at index %d, expected %s, got %s", i, eMsg, aMsg)
		}
		// Fall through to the next entry. The original returned nil here, which
		// silently skipped every remaining write error.
	}
	return nil
}
401
// compareSucceededEvent retrieves a CommandSucceededEvent from mt and checks it
// against expectation.CommandSucceededEvent.
//
// The command name is compared only when the expectation sets one. Every key of
// the expected reply must be present in the actual reply and must match it.
// "writeErrors" is compared with compareWriteErrors and every other key with
// compareValues. Keys present only in the actual reply are ignored.
func compareSucceededEvent(mt *mtest.T, expectation *expectation) error {
	mt.Helper()

	expected := expectation.CommandSucceededEvent
	if len(expected.Extra) > 0 {
		return fmt.Errorf("unrecognized fields for CommandSucceededEvent: %v", expected.Extra)
	}
	evt := mt.GetSucceededEvent()
	if evt == nil {
		return errors.New("expected CommandSucceededEvent, got nil")
	}

	if expected.CommandName != "" && expected.CommandName != evt.CommandName {
		return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName)
	}

	eElems, err := expected.Reply.Elements()
	if err != nil {
		return fmt.Errorf("error getting expected reply elements: %s", err)
	}

	for _, elem := range eElems {
		key := elem.Key()
		val := elem.Value()

		// Lookup returns a zero RawValue for a missing key. Array() would then
		// panic and compareValues would fail with an unhelpful type error, so
		// report the missing key directly.
		actualVal, err := evt.Reply.LookupErr(key)
		if err != nil {
			return newMatchError(mt, expected.Reply, evt.Reply, "expected reply to contain key %q: %v", key, err)
		}

		switch key {
		case "writeErrors":
			// RawValue.Array panics on non-array values.
			if actualVal.Type != bson.TypeArray {
				return newMatchError(mt, expected.Reply, evt.Reply, "expected writeErrors to be an array, got %s", actualVal.Type)
			}
			if err := compareWriteErrors(mt, val.Array(), actualVal.Array()); err != nil {
				return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
			}
		default:
			if err := compareValues(mt, key, val, actualVal); err != nil {
				return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
			}
		}
	}
	return nil
}
441
// compareFailedEvent retrieves a CommandFailedEvent from mt via GetFailedEvent
// and checks it against expectation.CommandFailedEvent. It returns an error in
// three cases:
//   - the expectation carries unrecognized fields;
//   - no event is available;
//   - the expectation names a command that differs from the event's command.
//
// Nothing besides the command name is compared.
func compareFailedEvent(mt *mtest.T, expectation *expectation) error {
	mt.Helper()

	want := expectation.CommandFailedEvent
	if len(want.Extra) != 0 {
		return fmt.Errorf("unrecognized fields for CommandFailedEvent: %v", want.Extra)
	}

	got := mt.GetFailedEvent()
	switch {
	case got == nil:
		return errors.New("expected CommandFailedEvent, got nil")
	case want.CommandName == "" || want.CommandName == got.CommandName:
		return nil
	default:
		return fmt.Errorf("command name mismatch; expected %s, got %s", want.CommandName, got.CommandName)
	}
}
459
View as plain text