1
2
3
4
5
6
7 package integration
8
9 import (
10 "crypto/tls"
11 "errors"
12 "fmt"
13 "io/ioutil"
14 "math"
15 "os"
16 "path"
17 "strings"
18 "testing"
19 "time"
20
21 "go.mongodb.org/mongo-driver/bson"
22 "go.mongodb.org/mongo-driver/internal/assert"
23 "go.mongodb.org/mongo-driver/mongo"
24 "go.mongodb.org/mongo-driver/mongo/options"
25 "go.mongodb.org/mongo-driver/mongo/readconcern"
26 "go.mongodb.org/mongo-driver/mongo/readpref"
27 "go.mongodb.org/mongo-driver/mongo/writeconcern"
28 )
29
// Credentials and TLS file paths consumed by createKmsProvidersMap and
// createTLSOptsMap. Each value is read from the environment once, at package
// initialization; an unset variable yields the empty string.

// AWS KMS, long-lived credentials.
var (
	awsAccessKeyID     = os.Getenv("FLE_AWS_KEY")
	awsSecretAccessKey = os.Getenv("FLE_AWS_SECRET")
)

// AWS KMS, temporary credentials (used by the "awsTemporary" and
// "awsTemporaryNoSessionToken" providers).
var (
	awsTempAccessKeyID     = os.Getenv("CSFLE_AWS_TEMP_ACCESS_KEY_ID")
	awsTempSecretAccessKey = os.Getenv("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY")
	awsTempSessionToken    = os.Getenv("CSFLE_AWS_TEMP_SESSION_TOKEN")
)

// Azure KMS.
var (
	azureTenantID     = os.Getenv("FLE_AZURE_TENANTID")
	azureClientID     = os.Getenv("FLE_AZURE_CLIENTID")
	azureClientSecret = os.Getenv("FLE_AZURE_CLIENTSECRET")
)

// GCP KMS.
var (
	gcpEmail      = os.Getenv("FLE_GCP_EMAIL")
	gcpPrivateKey = os.Getenv("FLE_GCP_PRIVATEKEY")
)

// KMIP: file paths handed to options.BuildTLSConfig.
var (
	tlsCAFileKMIP                   = os.Getenv("CSFLE_TLS_CA_FILE")
	tlsClientCertificateKeyFileKMIP = os.Getenv("CSFLE_TLS_CLIENT_CERT_FILE")
)
44
45
46
47
48
49
50 func jsonFilesInDir(t testing.TB, dir string) []string {
51 t.Helper()
52
53 files := make([]string, 0)
54
55 entries, err := ioutil.ReadDir(dir)
56 assert.Nil(t, err, "unable to read json file: %v", err)
57
58 for _, entry := range entries {
59 if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
60 continue
61 }
62
63 files = append(files, entry.Name())
64 }
65
66 return files
67 }
68
69
70 func createClientOptions(t testing.TB, opts bson.Raw) *options.ClientOptions {
71 t.Helper()
72
73 clientOpts := options.Client()
74 elems, _ := opts.Elements()
75 for _, elem := range elems {
76 name := elem.Key()
77 opt := elem.Value()
78
79 switch name {
80 case "retryWrites":
81 clientOpts.SetRetryWrites(opt.Boolean())
82 case "w":
83 switch opt.Type {
84 case bson.TypeInt32:
85 w := int(opt.Int32())
86 clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w)))
87 case bson.TypeDouble:
88 w := int(opt.Double())
89 clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w)))
90 case bson.TypeString:
91 clientOpts.SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
92 default:
93 t.Fatalf("unrecognized type for w client option: %v", opt.Type)
94 }
95 case "readConcernLevel":
96 clientOpts.SetReadConcern(readconcern.New(readconcern.Level(opt.StringValue())))
97 case "readPreference":
98 clientOpts.SetReadPreference(readPrefFromString(opt.StringValue()))
99 case "heartbeatFrequencyMS":
100 hf := convertValueToMilliseconds(t, opt)
101 clientOpts.SetHeartbeatInterval(hf)
102 case "retryReads":
103 clientOpts.SetRetryReads(opt.Boolean())
104 case "autoEncryptOpts":
105 clientOpts.SetAutoEncryptionOptions(createAutoEncryptionOptions(t, opt.Document()))
106 case "appname":
107 clientOpts.SetAppName(opt.StringValue())
108 case "connectTimeoutMS":
109 ct := convertValueToMilliseconds(t, opt)
110 clientOpts.SetConnectTimeout(ct)
111 case "serverSelectionTimeoutMS":
112 sst := convertValueToMilliseconds(t, opt)
113 clientOpts.SetServerSelectionTimeout(sst)
114 case "socketTimeoutMS":
115 st := convertValueToMilliseconds(t, opt)
116 clientOpts.SetSocketTimeout(st)
117 case "minPoolSize":
118 clientOpts.SetMinPoolSize(uint64(opt.AsInt64()))
119 case "maxPoolSize":
120 clientOpts.SetMaxPoolSize(uint64(opt.AsInt64()))
121 case "directConnection":
122 clientOpts.SetDirect(opt.Boolean())
123 default:
124 t.Fatalf("unrecognized client option: %v", name)
125 }
126 }
127
128 return clientOpts
129 }
130
131 func createAutoEncryptionOptions(t testing.TB, opts bson.Raw) *options.AutoEncryptionOptions {
132 t.Helper()
133
134 aeo := options.AutoEncryption()
135 var kvnsFound bool
136 elems, _ := opts.Elements()
137
138 for _, elem := range elems {
139 name := elem.Key()
140 opt := elem.Value()
141
142 switch name {
143 case "kmsProviders":
144 tlsConfigs := createTLSOptsMap(t, opt.Document())
145 aeo.SetKmsProviders(createKmsProvidersMap(t, opt.Document())).SetTLSConfig(tlsConfigs)
146 case "schemaMap":
147 var schemaMap map[string]interface{}
148 err := bson.Unmarshal(opt.Document(), &schemaMap)
149 if err != nil {
150 t.Fatalf("error creating schema map: %v", err)
151 }
152
153 aeo.SetSchemaMap(schemaMap)
154 case "keyVaultNamespace":
155 kvnsFound = true
156 aeo.SetKeyVaultNamespace(opt.StringValue())
157 case "bypassAutoEncryption":
158 aeo.SetBypassAutoEncryption(opt.Boolean())
159 case "encryptedFieldsMap":
160 var encryptedFieldsMap map[string]interface{}
161 err := bson.Unmarshal(opt.Document(), &encryptedFieldsMap)
162 if err != nil {
163 t.Fatalf("error creating encryptedFieldsMap: %v", err)
164 }
165 aeo.SetEncryptedFieldsMap(encryptedFieldsMap)
166 case "bypassQueryAnalysis":
167 aeo.SetBypassQueryAnalysis(opt.Boolean())
168 default:
169 t.Fatalf("unrecognized auto encryption option: %v", name)
170 }
171 }
172 if !kvnsFound {
173 aeo.SetKeyVaultNamespace("keyvault.datakeys")
174 }
175
176 return aeo
177 }
178
179 func createTLSOptsMap(t testing.TB, opts bson.Raw) map[string]*tls.Config {
180 t.Helper()
181
182 tlsMap := make(map[string]*tls.Config)
183 elems, _ := opts.Elements()
184
185 for _, elem := range elems {
186 provider := elem.Key()
187
188 if provider == "kmip" {
189 tlsOptsMap := map[string]interface{}{
190 "tlsCertificateKeyFile": tlsClientCertificateKeyFileKMIP,
191 "tlsCAFile": tlsCAFileKMIP,
192 }
193
194 cfg, err := options.BuildTLSConfig(tlsOptsMap)
195 if err != nil {
196 t.Fatalf("error building TLS config map: %v", err)
197 }
198
199 tlsMap["kmip"] = cfg
200 }
201 }
202 return tlsMap
203 }
204
205 func createKmsProvidersMap(t testing.TB, opts bson.Raw) map[string]map[string]interface{} {
206 t.Helper()
207
208 kmsMap := make(map[string]map[string]interface{})
209 elems, _ := opts.Elements()
210
211 for _, elem := range elems {
212 provider := elem.Key()
213 providerOpt := elem.Value()
214
215 switch provider {
216 case "aws":
217 awsMap := map[string]interface{}{
218 "accessKeyId": awsAccessKeyID,
219 "secretAccessKey": awsSecretAccessKey,
220 }
221 kmsMap["aws"] = awsMap
222 case "azure":
223 kmsMap["azure"] = map[string]interface{}{
224 "tenantId": azureTenantID,
225 "clientId": azureClientID,
226 "clientSecret": azureClientSecret,
227 }
228 case "gcp":
229 kmsMap["gcp"] = map[string]interface{}{
230 "email": gcpEmail,
231 "privateKey": gcpPrivateKey,
232 }
233 case "local":
234 _, key := providerOpt.Document().Lookup("key").Binary()
235 localMap := map[string]interface{}{
236 "key": key,
237 }
238 kmsMap["local"] = localMap
239 case "awsTemporary":
240 if awsTempAccessKeyID == "" {
241 t.Fatal("AWS temp access key ID not set")
242 }
243 if awsTempSecretAccessKey == "" {
244 t.Fatal("AWS temp secret access key not set")
245 }
246 if awsTempSessionToken == "" {
247 t.Fatal("AWS temp session token not set")
248 }
249 awsMap := map[string]interface{}{
250 "accessKeyId": awsTempAccessKeyID,
251 "secretAccessKey": awsTempSecretAccessKey,
252 "sessionToken": awsTempSessionToken,
253 }
254 kmsMap["aws"] = awsMap
255 case "awsTemporaryNoSessionToken":
256 if awsTempAccessKeyID == "" {
257 t.Fatal("AWS temp access key ID not set")
258 }
259 if awsTempSecretAccessKey == "" {
260 t.Fatal("AWS temp secret access key not set")
261 }
262 awsMap := map[string]interface{}{
263 "accessKeyId": awsTempAccessKeyID,
264 "secretAccessKey": awsTempSecretAccessKey,
265 }
266 kmsMap["aws"] = awsMap
267 case "kmip":
268 kmipMap := map[string]interface{}{
269 "endpoint": "localhost:5698",
270 }
271 kmsMap["kmip"] = kmipMap
272 default:
273 t.Fatalf("unrecognized KMS provider: %v", provider)
274 }
275 }
276
277 return kmsMap
278 }
279
280
281 func createSessionOptions(t testing.TB, opts bson.Raw) *options.SessionOptions {
282 t.Helper()
283
284 sessOpts := options.Session()
285 elems, _ := opts.Elements()
286 for _, elem := range elems {
287 name := elem.Key()
288 opt := elem.Value()
289
290 switch name {
291 case "causalConsistency":
292 sessOpts = sessOpts.SetCausalConsistency(opt.Boolean())
293 case "defaultTransactionOptions":
294 txnOpts := createTransactionOptions(t, opt.Document())
295 if txnOpts.ReadConcern != nil {
296 sessOpts.SetDefaultReadConcern(txnOpts.ReadConcern)
297 }
298 if txnOpts.ReadPreference != nil {
299 sessOpts.SetDefaultReadPreference(txnOpts.ReadPreference)
300 }
301 if txnOpts.WriteConcern != nil {
302 sessOpts.SetDefaultWriteConcern(txnOpts.WriteConcern)
303 }
304 if txnOpts.MaxCommitTime != nil {
305 sessOpts.SetDefaultMaxCommitTime(txnOpts.MaxCommitTime)
306 }
307 default:
308 t.Fatalf("unrecognized session option: %v", name)
309 }
310 }
311
312 return sessOpts
313 }
314
315
316 func createDatabaseOptions(t testing.TB, opts bson.Raw) *options.DatabaseOptions {
317 t.Helper()
318
319 do := options.Database()
320 elems, _ := opts.Elements()
321 for _, elem := range elems {
322 name := elem.Key()
323 opt := elem.Value()
324
325 switch name {
326 case "readConcern":
327 do.SetReadConcern(createReadConcern(opt))
328 case "writeConcern":
329 do.SetWriteConcern(createWriteConcern(t, opt))
330 default:
331 t.Fatalf("unrecognized database option: %v", name)
332 }
333 }
334
335 return do
336 }
337
338
339 func createCollectionOptions(t testing.TB, opts bson.Raw) *options.CollectionOptions {
340 t.Helper()
341
342 co := options.Collection()
343 elems, _ := opts.Elements()
344 for _, elem := range elems {
345 name := elem.Key()
346 opt := elem.Value()
347
348 switch name {
349 case "readConcern":
350 co.SetReadConcern(createReadConcern(opt))
351 case "writeConcern":
352 co.SetWriteConcern(createWriteConcern(t, opt))
353 case "readPreference":
354 co.SetReadPreference(createReadPref(opt))
355 default:
356 t.Fatalf("unrecognized collection option: %v", name)
357 }
358 }
359
360 return co
361 }
362
363
364 func createTransactionOptions(t testing.TB, opts bson.Raw) *options.TransactionOptions {
365 t.Helper()
366
367 txnOpts := options.Transaction()
368 elems, _ := opts.Elements()
369 for _, elem := range elems {
370 name := elem.Key()
371 opt := elem.Value()
372
373 switch name {
374 case "writeConcern":
375 txnOpts.SetWriteConcern(createWriteConcern(t, opt))
376 case "readPreference":
377 txnOpts.SetReadPreference(createReadPref(opt))
378 case "readConcern":
379 txnOpts.SetReadConcern(createReadConcern(opt))
380 case "maxCommitTimeMS":
381 t := time.Duration(opt.Int32()) * time.Millisecond
382 txnOpts.SetMaxCommitTime(&t)
383 default:
384 t.Fatalf("unrecognized transaction option: %v", opt)
385 }
386 }
387 return txnOpts
388 }
389
390
391 func createReadConcern(opt bson.RawValue) *readconcern.ReadConcern {
392 return readconcern.New(readconcern.Level(opt.Document().Lookup("level").StringValue()))
393 }
394
395
396 func createWriteConcern(t testing.TB, opt bson.RawValue) *writeconcern.WriteConcern {
397 wcDoc, ok := opt.DocumentOK()
398 if !ok {
399 return nil
400 }
401
402 var opts []writeconcern.Option
403 elems, _ := wcDoc.Elements()
404 for _, elem := range elems {
405 key := elem.Key()
406 val := elem.Value()
407
408 switch key {
409 case "wtimeout":
410 wtimeout := convertValueToMilliseconds(t, val)
411 opts = append(opts, writeconcern.WTimeout(wtimeout))
412 case "j":
413 opts = append(opts, writeconcern.J(val.Boolean()))
414 case "w":
415 switch val.Type {
416 case bson.TypeString:
417 if val.StringValue() != "majority" {
418 break
419 }
420 opts = append(opts, writeconcern.WMajority())
421 case bson.TypeInt32:
422 w := int(val.Int32())
423 opts = append(opts, writeconcern.W(w))
424 default:
425 t.Fatalf("unrecognized type for w: %v", val.Type)
426 }
427 default:
428 t.Fatalf("unrecognized write concern option: %v", key)
429 }
430 }
431 return writeconcern.New(opts...)
432 }
433
434
435
436 func readPrefFromString(s string) *readpref.ReadPref {
437 switch strings.ToLower(s) {
438 case "primary":
439 return readpref.Primary()
440 case "primarypreferred":
441 return readpref.PrimaryPreferred()
442 case "secondary":
443 return readpref.Secondary()
444 case "secondarypreferred":
445 return readpref.SecondaryPreferred()
446 case "nearest":
447 return readpref.Nearest()
448 }
449 return readpref.Primary()
450 }
451
452
453 func createReadPref(opt bson.RawValue) *readpref.ReadPref {
454 mode := opt.Document().Lookup("mode").StringValue()
455 return readPrefFromString(mode)
456 }
457
458
459 func errorFromResult(t testing.TB, result interface{}) *operationError {
460 t.Helper()
461
462
463 raw, ok := result.(bson.Raw)
464 if !ok {
465 return nil
466 }
467
468 var expected operationError
469 err := bson.Unmarshal(raw, &expected)
470 if err != nil {
471 return nil
472 }
473 if expected.ErrorCodeName == nil && expected.ErrorContains == nil && len(expected.ErrorLabelsOmit) == 0 &&
474 len(expected.ErrorLabelsContain) == 0 {
475 return nil
476 }
477
478 return &expected
479 }
480
481
482
483 type errorDetails struct {
484 name string
485 labels []string
486 }
487
488
489
490 func extractErrorDetails(err error) (errorDetails, bool) {
491 var details errorDetails
492
493 switch converted := err.(type) {
494 case mongo.CommandError:
495 details.name = converted.Name
496 details.labels = converted.Labels
497 case mongo.WriteException:
498 if converted.WriteConcernError != nil {
499 details.name = converted.WriteConcernError.Name
500 }
501 details.labels = converted.Labels
502 case mongo.BulkWriteException:
503 if converted.WriteConcernError != nil {
504 details.name = converted.WriteConcernError.Name
505 }
506 details.labels = converted.Labels
507 default:
508 return errorDetails{}, false
509 }
510
511 return details, true
512 }
513
514
515 func verifyError(expected *operationError, actual error) error {
516
517
518 if errors.Is(actual, mongo.ErrNoDocuments) || errors.Is(actual, mongo.ErrUnacknowledgedWrite) {
519 actual = nil
520 }
521
522 if expected == nil && actual != nil {
523 return fmt.Errorf("did not expect error but got %w", actual)
524 }
525 if expected != nil && actual == nil {
526 return fmt.Errorf("expected error but got nil")
527 }
528 if expected == nil {
529 return nil
530 }
531
532
533 if expected.ErrorContains != nil {
534 emsg := strings.ToLower(*expected.ErrorContains)
535 amsg := strings.ToLower(actual.Error())
536 if !strings.Contains(amsg, emsg) {
537 return fmt.Errorf("expected error message %q to contain %q", amsg, emsg)
538 }
539 }
540
541
542
543 details, ok := extractErrorDetails(actual)
544 if !ok {
545 if expected.ErrorCodeName != nil || len(expected.ErrorLabelsContain) > 0 || len(expected.ErrorLabelsOmit) > 0 {
546 return fmt.Errorf("failed to extract details from error %v of type %T", actual, actual)
547 }
548 return nil
549 }
550
551 if expected.ErrorCodeName != nil {
552 if *expected.ErrorCodeName != details.name {
553 return fmt.Errorf("expected error name %v, got %v", *expected.ErrorCodeName, details.name)
554 }
555 }
556 for _, label := range expected.ErrorLabelsContain {
557 if !stringSliceContains(details.labels, label) {
558 return fmt.Errorf("expected error %w to contain label %q", actual, label)
559 }
560 }
561 for _, label := range expected.ErrorLabelsOmit {
562 if stringSliceContains(details.labels, label) {
563 return fmt.Errorf("expected error %w to not contain label %q", actual, label)
564 }
565 }
566 return nil
567 }
568
569
// getIntFromInterface converts i to an int64 and returns a pointer to it, or
// nil when i cannot be represented exactly as an int64.
//
// int, int32 and int64 always convert. float32 and float64 convert only when
// the value is integral and lies in [math.MinInt64, math.MaxInt64]. NaN,
// ±Inf, fractional values and out-of-range values yield nil, as does any
// other type.
func getIntFromInterface(i interface{}) *int64 {
	var f float64

	switch v := i.(type) {
	case int:
		out := int64(v)
		return &out
	case int32:
		out := int64(v)
		return &out
	case int64:
		return &v
	case float32:
		f = float64(v)
	case float64:
		f = v
	default:
		return nil
	}

	// Only floats reach this point. Trunc(NaN) != NaN, so NaN is rejected
	// here. float64(math.MaxInt64) rounds up to 2^63, which does not fit in an
	// int64, hence >=. Infinities fail the range checks.
	if math.Trunc(f) != f || f < float64(math.MinInt64) || f >= float64(math.MaxInt64) {
		return nil
	}
	out := int64(f)
	return &out
}
599
600 func createCollation(t testing.TB, m bson.Raw) *options.Collation {
601 var collation options.Collation
602 elems, _ := m.Elements()
603
604 for _, elem := range elems {
605 switch elem.Key() {
606 case "locale":
607 collation.Locale = elem.Value().StringValue()
608 case "caseLevel":
609 collation.CaseLevel = elem.Value().Boolean()
610 case "caseFirst":
611 collation.CaseFirst = elem.Value().StringValue()
612 case "strength":
613 collation.Strength = int(elem.Value().Int32())
614 case "numericOrdering":
615 collation.NumericOrdering = elem.Value().Boolean()
616 case "alternate":
617 collation.Alternate = elem.Value().StringValue()
618 case "maxVariable":
619 collation.MaxVariable = elem.Value().StringValue()
620 case "normalization":
621 collation.Normalization = elem.Value().Boolean()
622 case "backwards":
623 collation.Backwards = elem.Value().Boolean()
624 default:
625 t.Fatalf("unrecognized collation option: %v", elem.Key())
626 }
627 }
628 return &collation
629 }
630
631 func convertValueToMilliseconds(t testing.TB, val bson.RawValue) time.Duration {
632 t.Helper()
633
634 int32Val, ok := val.Int32OK()
635 if !ok {
636 t.Fatalf("failed to convert value of type %s to int32", val.Type)
637 }
638 return time.Duration(int32Val) * time.Millisecond
639 }
640
// stringSliceContains reports whether target occurs in stringSlice, using
// exact, case-sensitive comparison. A nil or empty slice contains nothing.
func stringSliceContains(stringSlice []string, target string) bool {
	for i := range stringSlice {
		if stringSlice[i] == target {
			return true
		}
	}
	return false
}
649
View as plain text