1
2
3
4
5
6
7 package unified
8
9 import (
10 "bytes"
11 "context"
12 "crypto/tls"
13 "errors"
14 "fmt"
15 "os"
16 "sync"
17 "sync/atomic"
18 "time"
19
20 "go.mongodb.org/mongo-driver/bson"
21 "go.mongodb.org/mongo-driver/mongo"
22 "go.mongodb.org/mongo-driver/mongo/gridfs"
23 "go.mongodb.org/mongo-driver/mongo/options"
24 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
25 )
26
// ErrEntityMapOpen is returned by EventList and BSONArray when they are called
// before the EntityMap has been closed.
var ErrEntityMapOpen = errors.New("slices cannot be accessed while EntityMap is open")
31
// Certificate files for the KMIP TLS configuration built in
// addClientEncryptionEntity. Both are read from the environment once, when the
// package is initialized; the TLS config is only built when both are non-empty.
var (
	// tlsCAFile is the CA file path taken from CSFLE_TLS_CA_FILE.
	tlsCAFile = os.Getenv("CSFLE_TLS_CA_FILE")
	// tlsClientCertificateKeyFile is the client certificate/key file path taken
	// from CSFLE_TLS_CLIENT_CERT_FILE.
	tlsClientCertificateKeyFile = os.Getenv("CSFLE_TLS_CLIENT_CERT_FILE")
)
36
// storeEventsAsEntitiesConfig is one entry of a client's storeEventsAsEntities
// option. addClientEntity creates an event-list entity with ID EventListID for
// each entry before constructing the client.
type storeEventsAsEntitiesConfig struct {
	// EventListID is the ID of the event-list entity to create; it must differ
	// from the client's own ID.
	EventListID string `bson:"id"`
	// Events lists the events to record (presumably event type names — TODO
	// confirm against newClientEntity).
	Events []string `bson:"events"`
}
41
// observeLogMessages is the client entity's observeLogMessages option, decoded
// from the "observeLogMessages" document. It has one string field per logging
// component.
// NOTE(review): each value presumably names the log severity to observe for
// that component — confirm against newClientEntity.
type observeLogMessages struct {
	Command         string `bson:"command"`
	Topology        string `bson:"topology"`
	ServerSelection string `bson:"serverSelection"`
	Connection      string `bson:"connection"`
}
48
49
50
// entityOptions is the union of options for every entity type. Each entity
// constructor reads only the fields relevant to its own type.
type entityOptions struct {
	// ID is the identifier the entity is registered under in the EntityMap.
	ID string `bson:"id"`

	// Client entity options. The whole struct is passed to newClientEntity;
	// StoreEventsAsEntities is additionally consumed by addClientEntity.
	URIOptions               bson.M                        `bson:"uriOptions"`
	UseMultipleMongoses      *bool                         `bson:"useMultipleMongoses"`
	ObserveEvents            []string                      `bson:"observeEvents"`
	IgnoredCommands          []string                      `bson:"ignoreCommandMonitoringEvents"`
	ObserveSensitiveCommands *bool                         `bson:"observeSensitiveCommands"`
	StoreEventsAsEntities    []storeEventsAsEntitiesConfig `bson:"storeEventsAsEntities"`
	ServerAPIOptions         *serverAPIOptions             `bson:"serverApi"`

	// ObserveLogMessages is also a client entity option.
	ObserveLogMessages *observeLogMessages `bson:"observeLogMessages"`

	// Database entity options, read by addDatabaseEntity.
	DatabaseName    string                 `bson:"databaseName"`
	DatabaseOptions *dbOrCollectionOptions `bson:"databaseOptions"`

	// Collection entity options, read by addCollectionEntity.
	CollectionName    string                 `bson:"collectionName"`
	CollectionOptions *dbOrCollectionOptions `bson:"collectionOptions"`

	// Session entity options, read by addSessionEntity.
	SessionOptions *sessionOptions `bson:"sessionOptions"`

	// GridFS bucket entity options, read by addGridFSBucketEntity.
	GridFSBucketOptions *gridFSBucketOptions `bson:"bucketOptions"`

	// IDs of parent entities. ClientID is used by database and session
	// entities; DatabaseID is used by collection and bucket entities.
	ClientID   string `bson:"client"`
	DatabaseID string `bson:"database"`

	// ClientEncryptionOpts configures clientEncryption entities, read by
	// addClientEncryptionEntity.
	ClientEncryptionOpts *clientEncryptionOpts `bson:"clientEncryptionOpts"`
}
87
88 func (eo *entityOptions) setHeartbeatFrequencyMS(freq time.Duration) {
89 if eo.URIOptions == nil {
90 eo.URIOptions = make(bson.M)
91 }
92
93 if _, ok := eo.URIOptions["heartbeatFrequencyMS"]; !ok {
94
95
96
97 eo.URIOptions["heartbeatFrequencyMS"] = int32(freq.Milliseconds())
98 }
99 }
100
101
102
103 func newCollectionEntityOptions(id string, databaseID string, collectionName string,
104 opts *dbOrCollectionOptions) *entityOptions {
105 options := &entityOptions{
106 ID: id,
107 DatabaseID: databaseID,
108 CollectionName: collectionName,
109 CollectionOptions: opts,
110 }
111
112 return options
113 }
114
// task is a named unit of work queued on a backgroundRoutine.
type task struct {
	name    string
	execute func() error
}

// backgroundRoutine executes queued tasks sequentially on a single worker
// goroutine; it backs the "thread" entity type. After the first failure or
// timeout, later tasks are still dequeued but no longer executed, and the
// failure is reported by stop.
type backgroundRoutine struct {
	tasks chan *task
	wg    sync.WaitGroup
	// err holds the first failure. Only the worker goroutine writes it, and
	// stop reads it after wg.Wait, so no further synchronization is needed.
	err error
}

// start launches the worker goroutine. Each task executes on its own goroutine
// so the worker can stop waiting for it after 10 seconds.
func (b *backgroundRoutine) start() {
	b.wg.Add(1)

	go func() {
		defer b.wg.Done()

		for t := range b.tasks {
			// Keep consuming the queue after a failure, but skip execution.
			if b.err != nil {
				continue
			}

			// Buffered so the task goroutine can always deliver its result and
			// exit, even after the worker gave up on it because of the timeout.
			// With an unbuffered channel that goroutine would block forever.
			done := make(chan error, 1)
			go func(tk *task) {
				done <- tk.execute()
			}(t)

			select {
			case err := <-done:
				if err != nil {
					b.err = fmt.Errorf("error running operation %s: %w", t.name, err)
				}
			case <-time.After(10 * time.Second):
				b.err = fmt.Errorf("timed out after 10 seconds")
			}
		}
	}()
}

// stop closes the task queue, waits for the worker to process everything still
// queued, and returns the first error recorded (nil if every task succeeded).
// It closes the channel, so it must be called at most once, and addTask must
// not be called afterwards.
func (b *backgroundRoutine) stop() error {
	close(b.tasks)
	b.wg.Wait()
	return b.err
}

// addTask enqueues execute under the given name without blocking. It returns
// false, and drops the task, when the queue's buffer is full.
func (b *backgroundRoutine) addTask(name string, execute func() error) bool {
	select {
	case b.tasks <- &task{
		name:    name,
		execute: execute,
	}:
		return true
	default:
		return false
	}
}

// newBackgroundRoutine returns a routine whose queue buffers up to 10 tasks.
// The worker is not running until start is called.
func newBackgroundRoutine() *backgroundRoutine {
	return &backgroundRoutine{
		tasks: make(chan *task, 10),
	}
}
178
// clientEncryptionOpts configures a clientEncryption entity (see
// addClientEncryptionEntity).
type clientEncryptionOpts struct {
	// KeyVaultClient is the ID of an existing client entity used as the key
	// vault client.
	KeyVaultClient string `bson:"keyVaultClient"`
	// KeyVaultNamespace is passed to SetKeyVaultNamespace.
	KeyVaultNamespace string `bson:"keyVaultNamespace"`
	// KmsProviders maps a provider name to its raw credentials document;
	// addClientEncryptionEntity reads the "aws", "azure", "gcp", "kmip" and
	// "local" entries.
	KmsProviders map[string]bson.Raw `bson:"kmsProviders"`
}
184
185
186
187
188
// EntityMap stores the entities created during a test, keyed by entity ID.
// allEntities records every ID in use across all entity kinds; the per-kind
// maps hold the entities themselves.
type EntityMap struct {
	allEntities              map[string]struct{}
	cursorEntities           map[string]cursor
	clientEntities           map[string]*clientEntity
	dbEntites                map[string]*mongo.Database
	collEntities             map[string]*mongo.Collection
	sessions                 map[string]mongo.Session
	gridfsBuckets            map[string]*gridfs.Bucket
	bsonValues               map[string]bson.RawValue
	eventListEntities        map[string][]bson.Raw
	bsonArrayEntities        map[string][]bson.Raw
	successValues            map[string]int32
	iterationValues          map[string]int32
	clientEncryptionEntities map[string]*mongo.ClientEncryption
	// routinesMap maps each "thread" entity ID to its *backgroundRoutine.
	routinesMap sync.Map
	// evtLock guards eventListEntities in appendEventsEntity, which is expected
	// to run concurrently (presumably from client event monitors — TODO confirm).
	evtLock sync.Mutex
	// closed holds a bool set to true by close(); EventList and BSONArray
	// refuse access until then.
	closed atomic.Value

	// keyVaultClientIDs holds the IDs of client entities used as the key vault
	// client of a ClientEncryption entity. close() does not disconnect these
	// clients directly; it leaves them to the ClientEncryption entities it
	// closes.
	keyVaultClientIDs map[string]bool
}
211
212 func (em *EntityMap) isClosed() bool {
213 return em.closed.Load().(bool)
214 }
215
216 func (em *EntityMap) setClosed(val bool) {
217 em.closed.Store(val)
218 }
219
220 func newEntityMap() *EntityMap {
221 em := &EntityMap{
222 allEntities: make(map[string]struct{}),
223 gridfsBuckets: make(map[string]*gridfs.Bucket),
224 bsonValues: make(map[string]bson.RawValue),
225 cursorEntities: make(map[string]cursor),
226 clientEntities: make(map[string]*clientEntity),
227 collEntities: make(map[string]*mongo.Collection),
228 dbEntites: make(map[string]*mongo.Database),
229 sessions: make(map[string]mongo.Session),
230 eventListEntities: make(map[string][]bson.Raw),
231 bsonArrayEntities: make(map[string][]bson.Raw),
232 successValues: make(map[string]int32),
233 iterationValues: make(map[string]int32),
234 clientEncryptionEntities: make(map[string]*mongo.ClientEncryption),
235 keyVaultClientIDs: make(map[string]bool),
236 }
237 em.setClosed(false)
238 return em
239 }
240
241 func (em *EntityMap) addBSONEntity(id string, val bson.RawValue) error {
242 if err := em.verifyEntityDoesNotExist(id); err != nil {
243 return err
244 }
245
246 em.allEntities[id] = struct{}{}
247 em.bsonValues[id] = val
248 return nil
249 }
250
251 func (em *EntityMap) addCursorEntity(id string, cursor cursor) error {
252 if err := em.verifyEntityDoesNotExist(id); err != nil {
253 return err
254 }
255
256 em.allEntities[id] = struct{}{}
257 em.cursorEntities[id] = cursor
258 return nil
259 }
260
261 func (em *EntityMap) addBSONArrayEntity(id string) error {
262
263 if _, ok := em.allEntities[id]; ok {
264 if _, ok := em.bsonArrayEntities[id]; !ok {
265 return fmt.Errorf("non-BSON array entity with ID %q already exists", id)
266 }
267 return nil
268 }
269
270 em.allEntities[id] = struct{}{}
271 em.bsonArrayEntities[id] = []bson.Raw{}
272 return nil
273 }
274
275 func (em *EntityMap) addSuccessesEntity(id string) error {
276 if err := em.verifyEntityDoesNotExist(id); err != nil {
277 return err
278 }
279
280 em.allEntities[id] = struct{}{}
281 em.successValues[id] = 0
282 return nil
283 }
284
285 func (em *EntityMap) addIterationsEntity(id string) error {
286 if err := em.verifyEntityDoesNotExist(id); err != nil {
287 return err
288 }
289
290 em.allEntities[id] = struct{}{}
291 em.iterationValues[id] = 0
292 return nil
293 }
294
295 func (em *EntityMap) addEventsEntity(id string) error {
296 if err := em.verifyEntityDoesNotExist(id); err != nil {
297 return err
298 }
299 em.allEntities[id] = struct{}{}
300 em.eventListEntities[id] = []bson.Raw{}
301 return nil
302 }
303
304 func (em *EntityMap) incrementSuccesses(id string) error {
305 if _, ok := em.successValues[id]; !ok {
306 return newEntityNotFoundError("successes", id)
307 }
308 em.successValues[id]++
309 return nil
310 }
311
312 func (em *EntityMap) incrementIterations(id string) error {
313 if _, ok := em.iterationValues[id]; !ok {
314 return newEntityNotFoundError("iterations", id)
315 }
316 em.iterationValues[id]++
317 return nil
318 }
319
320 func (em *EntityMap) appendEventsEntity(id string, doc bson.Raw) {
321 em.evtLock.Lock()
322 defer em.evtLock.Unlock()
323 if _, ok := em.eventListEntities[id]; ok {
324 em.eventListEntities[id] = append(em.eventListEntities[id], doc)
325 }
326 }
327
328 func (em *EntityMap) appendBSONArrayEntity(id string, doc bson.Raw) error {
329 if _, ok := em.bsonArrayEntities[id]; !ok {
330 return newEntityNotFoundError("BSON array", id)
331 }
332 em.bsonArrayEntities[id] = append(em.bsonArrayEntities[id], doc)
333 return nil
334 }
335
336 func (em *EntityMap) addEntity(ctx context.Context, entityType string, entityOptions *entityOptions) error {
337 if err := em.verifyEntityDoesNotExist(entityOptions.ID); err != nil {
338 return err
339 }
340
341 var err error
342 switch entityType {
343 case "client":
344 err = em.addClientEntity(ctx, entityOptions)
345 case "database":
346 err = em.addDatabaseEntity(entityOptions)
347 case "collection":
348 err = em.addCollectionEntity(entityOptions)
349 case "session":
350 err = em.addSessionEntity(entityOptions)
351 case "thread":
352 routine := newBackgroundRoutine()
353 em.routinesMap.Store(entityOptions.ID, routine)
354 routine.start()
355 case "bucket":
356 err = em.addGridFSBucketEntity(entityOptions)
357 case "clientEncryption":
358 err = em.addClientEncryptionEntity(entityOptions)
359 default:
360 return fmt.Errorf("unrecognized entity type %q", entityType)
361 }
362
363 if err != nil {
364 return fmt.Errorf("error constructing entity of type %q: %w", entityType, err)
365 }
366 em.allEntities[entityOptions.ID] = struct{}{}
367 return nil
368 }
369
370 func (em *EntityMap) gridFSBucket(id string) (*gridfs.Bucket, error) {
371 bucket, ok := em.gridfsBuckets[id]
372 if !ok {
373 return nil, newEntityNotFoundError("gridfs bucket", id)
374 }
375 return bucket, nil
376 }
377
378 func (em *EntityMap) cursor(id string) (cursor, error) {
379 cursor, ok := em.cursorEntities[id]
380 if !ok {
381 return nil, newEntityNotFoundError("cursor", id)
382 }
383 return cursor, nil
384 }
385
386 func (em *EntityMap) client(id string) (*clientEntity, error) {
387 client, ok := em.clientEntities[id]
388 if !ok {
389 return nil, newEntityNotFoundError("client", id)
390 }
391 return client, nil
392 }
393
394 func (em *EntityMap) clientEncryption(id string) (*mongo.ClientEncryption, error) {
395 cee, ok := em.clientEncryptionEntities[id]
396 if !ok {
397 return nil, newEntityNotFoundError("client", id)
398 }
399 return cee, nil
400 }
401
// clients returns all client entities keyed by ID. The returned map is the
// EntityMap's own storage, not a copy.
func (em *EntityMap) clients() map[string]*clientEntity {
	return em.clientEntities
}

// collections returns all collection entities keyed by ID. The returned map is
// the EntityMap's own storage, not a copy.
func (em *EntityMap) collections() map[string]*mongo.Collection {
	return em.collEntities
}
409
410 func (em *EntityMap) collection(id string) (*mongo.Collection, error) {
411 coll, ok := em.collEntities[id]
412 if !ok {
413 return nil, newEntityNotFoundError("collection", id)
414 }
415 return coll, nil
416 }
417
418 func (em *EntityMap) database(id string) (*mongo.Database, error) {
419 db, ok := em.dbEntites[id]
420 if !ok {
421 return nil, newEntityNotFoundError("database", id)
422 }
423 return db, nil
424 }
425
426 func (em *EntityMap) session(id string) (mongo.Session, error) {
427 sess, ok := em.sessions[id]
428 if !ok {
429 return nil, newEntityNotFoundError("session", id)
430 }
431 return sess, nil
432 }
433
434
435 func (em *EntityMap) BSONValue(id string) (bson.RawValue, error) {
436 val, ok := em.bsonValues[id]
437 if !ok {
438 return emptyRawValue, newEntityNotFoundError("BSON", id)
439 }
440 return val, nil
441 }
442
443
444
445 func (em *EntityMap) EventList(id string) ([]bson.Raw, error) {
446 if !em.isClosed() {
447 return nil, ErrEntityMapOpen
448 }
449 val, ok := em.eventListEntities[id]
450 if !ok {
451 return nil, newEntityNotFoundError("event list", id)
452 }
453 return val, nil
454 }
455
456
457
458 func (em *EntityMap) BSONArray(id string) ([]bson.Raw, error) {
459 if !em.isClosed() {
460 return nil, ErrEntityMapOpen
461 }
462 val, ok := em.bsonArrayEntities[id]
463 if !ok {
464 return nil, newEntityNotFoundError("BSON array", id)
465 }
466 return val, nil
467 }
468
469
470 func (em *EntityMap) Successes(id string) (int32, error) {
471 val, ok := em.successValues[id]
472 if !ok {
473 return 0, newEntityNotFoundError("successes", id)
474 }
475 return val, nil
476 }
477
478
479 func (em *EntityMap) Iterations(id string) (int32, error) {
480 val, ok := em.iterationValues[id]
481 if !ok {
482 return 0, newEntityNotFoundError("iterations", id)
483 }
484 return val, nil
485 }
486
487
488 func (em *EntityMap) close(ctx context.Context) []error {
489 for _, sess := range em.sessions {
490 sess.EndSession(ctx)
491 }
492
493 var errs []error
494 for id, cursor := range em.cursorEntities {
495 if err := cursor.Close(ctx); err != nil {
496 errs = append(errs, fmt.Errorf("error closing cursor with ID %q: %w", id, err))
497 }
498 }
499
500 for id, client := range em.clientEntities {
501 if ok := em.keyVaultClientIDs[id]; ok {
502
503 continue
504 }
505
506 if err := client.disconnect(ctx); err != nil {
507 errs = append(errs, fmt.Errorf("error closing client with ID %q: %w", id, err))
508 }
509 }
510
511 for id, clientEncryption := range em.clientEncryptionEntities {
512 if err := clientEncryption.Close(ctx); err != nil {
513 errs = append(errs, fmt.Errorf("error closing clientEncryption with ID: %q: %w", id, err))
514 }
515 }
516
517 em.setClosed(true)
518 return errs
519 }
520
521 func (em *EntityMap) addClientEntity(ctx context.Context, entityOptions *entityOptions) error {
522 var client *clientEntity
523
524 for _, eventsAsEntity := range entityOptions.StoreEventsAsEntities {
525 if entityOptions.ID == eventsAsEntity.EventListID {
526 return fmt.Errorf("entity with ID %q already exists", entityOptions.ID)
527 }
528 if err := em.addEventsEntity(eventsAsEntity.EventListID); err != nil {
529 return err
530 }
531 }
532
533 client, err := newClientEntity(ctx, em, entityOptions)
534 if err != nil {
535 return fmt.Errorf("error creating client entity: %w", err)
536 }
537
538 em.clientEntities[entityOptions.ID] = client
539 return nil
540 }
541
542 func (em *EntityMap) addDatabaseEntity(entityOptions *entityOptions) error {
543 client, ok := em.clientEntities[entityOptions.ClientID]
544 if !ok {
545 return newEntityNotFoundError("client", entityOptions.ClientID)
546 }
547
548 dbOpts := options.Database()
549 if entityOptions.DatabaseOptions != nil {
550 dbOpts = entityOptions.DatabaseOptions.DBOptions
551 }
552
553 em.dbEntites[entityOptions.ID] = client.Database(entityOptions.DatabaseName, dbOpts)
554 return nil
555 }
556
557
558
559
560 func getKmsCredential(kmsDocument bson.Raw, credentialName string, envVar string, defaultValue string) (string, error) {
561 credentialVal, err := kmsDocument.LookupErr(credentialName)
562 if errors.Is(err, bsoncore.ErrElementNotFound) {
563 return "", nil
564 }
565 if err != nil {
566 return "", err
567 }
568
569 if str, ok := credentialVal.StringValueOK(); ok {
570 return str, nil
571 }
572
573 var ok bool
574 var doc bson.Raw
575 if doc, ok = credentialVal.DocumentOK(); !ok {
576 return "", fmt.Errorf("expected String or Document for %v, got: %v", credentialName, credentialVal)
577 }
578
579 placeholderDoc := bsoncore.NewDocumentBuilder().AppendInt32("$$placeholder", 1).Build()
580
581
582 if !bytes.Equal(doc, placeholderDoc) {
583 return "", fmt.Errorf("unexpected non-empty document for %v: %v", credentialName, doc)
584 }
585 if envVar == "" {
586 return defaultValue, nil
587 }
588 if os.Getenv(envVar) == "" {
589 if defaultValue != "" {
590 return defaultValue, nil
591 }
592 return "", fmt.Errorf("unable to get environment value for %v. Please set the CSFLE environment variable: %v", credentialName, envVar)
593 }
594 return os.Getenv(envVar), nil
595
596 }
597
598 func (em *EntityMap) addClientEncryptionEntity(entityOptions *entityOptions) error {
599
600 kmsProviders := make(map[string]map[string]interface{})
601 ceo := entityOptions.ClientEncryptionOpts
602 tlsconf := make(map[string]*tls.Config)
603 if aws, ok := ceo.KmsProviders["aws"]; ok {
604 kmsProviders["aws"] = make(map[string]interface{})
605
606 awsSessionToken, err := getKmsCredential(aws, "sessionToken", "CSFLE_AWS_TEMP_SESSION_TOKEN", "")
607 if err != nil {
608 return err
609 }
610 if awsSessionToken != "" {
611
612 kmsProviders["aws"]["sessionToken"] = awsSessionToken
613 awsAccessKeyID, err := getKmsCredential(aws, "accessKeyId", "CSFLE_AWS_TEMP_ACCESS_KEY_ID", "")
614 if err != nil {
615 return err
616 }
617 if awsAccessKeyID != "" {
618 kmsProviders["aws"]["accessKeyId"] = awsAccessKeyID
619 }
620
621 awsSecretAccessKey, err := getKmsCredential(aws, "secretAccessKey", "CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", "")
622 if err != nil {
623 return err
624 }
625 if awsSecretAccessKey != "" {
626 kmsProviders["aws"]["secretAccessKey"] = awsSecretAccessKey
627 }
628 } else {
629 awsAccessKeyID, err := getKmsCredential(aws, "accessKeyId", "FLE_AWS_KEY", "")
630 if err != nil {
631 return err
632 }
633 if awsAccessKeyID != "" {
634 kmsProviders["aws"]["accessKeyId"] = awsAccessKeyID
635 }
636
637 awsSecretAccessKey, err := getKmsCredential(aws, "secretAccessKey", "FLE_AWS_SECRET", "")
638 if err != nil {
639 return err
640 }
641 if awsSecretAccessKey != "" {
642 kmsProviders["aws"]["secretAccessKey"] = awsSecretAccessKey
643 }
644 }
645
646 }
647
648 if azure, ok := ceo.KmsProviders["azure"]; ok {
649 kmsProviders["azure"] = make(map[string]interface{})
650
651 azureTenantID, err := getKmsCredential(azure, "tenantId", "FLE_AZURE_TENANTID", "")
652 if err != nil {
653 return err
654 }
655 if azureTenantID != "" {
656 kmsProviders["azure"]["tenantId"] = azureTenantID
657 }
658
659 azureClientID, err := getKmsCredential(azure, "clientId", "FLE_AZURE_CLIENTID", "")
660 if err != nil {
661 return err
662 }
663 if azureClientID != "" {
664 kmsProviders["azure"]["clientId"] = azureClientID
665 }
666
667 azureClientSecret, err := getKmsCredential(azure, "clientSecret", "FLE_AZURE_CLIENTSECRET", "")
668 if err != nil {
669 return err
670 }
671 if azureClientSecret != "" {
672 kmsProviders["azure"]["clientSecret"] = azureClientSecret
673 }
674 }
675
676 if gcp, ok := ceo.KmsProviders["gcp"]; ok {
677 kmsProviders["gcp"] = make(map[string]interface{})
678
679 gcpEmail, err := getKmsCredential(gcp, "email", "FLE_GCP_EMAIL", "")
680 if err != nil {
681 return err
682 }
683 if gcpEmail != "" {
684 kmsProviders["gcp"]["email"] = gcpEmail
685 }
686
687 gcpPrivateKey, err := getKmsCredential(gcp, "privateKey", "FLE_GCP_PRIVATEKEY", "")
688 if err != nil {
689 return err
690 }
691 if gcpPrivateKey != "" {
692 kmsProviders["gcp"]["privateKey"] = gcpPrivateKey
693 }
694 }
695
696 if kmip, ok := ceo.KmsProviders["kmip"]; ok {
697 kmsProviders["kmip"] = make(map[string]interface{})
698
699 kmipEndpoint, err := getKmsCredential(kmip, "endpoint", "", "localhost:5698")
700 if err != nil {
701 return err
702 }
703
704 if tlsClientCertificateKeyFile != "" && tlsCAFile != "" {
705 cfg, err := options.BuildTLSConfig(map[string]interface{}{
706 "tlsCertificateKeyFile": tlsClientCertificateKeyFile,
707 "tlsCAFile": tlsCAFile,
708 })
709 if err != nil {
710 return fmt.Errorf("error constructing tls config: %w", err)
711 }
712 tlsconf["kmip"] = cfg
713 }
714
715 if kmipEndpoint != "" {
716 kmsProviders["kmip"]["endpoint"] = kmipEndpoint
717 }
718 }
719
720 if local, ok := ceo.KmsProviders["local"]; ok {
721 kmsProviders["local"] = make(map[string]interface{})
722
723 defaultLocalKeyBase64 := "Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
724 localKey, err := getKmsCredential(local, "key", "", defaultLocalKeyBase64)
725 if err != nil {
726 return err
727 }
728 if localKey != "" {
729 kmsProviders["local"]["key"] = localKey
730 }
731 }
732
733 em.keyVaultClientIDs[ceo.KeyVaultClient] = true
734 keyVaultClient, ok := em.clientEntities[ceo.KeyVaultClient]
735 if !ok {
736 return newEntityNotFoundError("client", ceo.KeyVaultClient)
737 }
738
739 ce, err := mongo.NewClientEncryption(
740 keyVaultClient.Client,
741 options.ClientEncryption().
742 SetKeyVaultNamespace(ceo.KeyVaultNamespace).
743 SetTLSConfig(tlsconf).
744 SetKmsProviders(kmsProviders))
745 if err != nil {
746 return err
747 }
748
749 em.clientEncryptionEntities[entityOptions.ID] = ce
750
751 return nil
752 }
753
754 func (em *EntityMap) addCollectionEntity(entityOptions *entityOptions) error {
755 db, ok := em.dbEntites[entityOptions.DatabaseID]
756 if !ok {
757 return newEntityNotFoundError("database", entityOptions.DatabaseID)
758 }
759
760 collOpts := options.Collection()
761 if entityOptions.CollectionOptions != nil {
762 collOpts = entityOptions.CollectionOptions.CollectionOptions
763 }
764
765 em.collEntities[entityOptions.ID] = db.Collection(entityOptions.CollectionName, collOpts)
766 return nil
767 }
768
769 func (em *EntityMap) addSessionEntity(entityOptions *entityOptions) error {
770 client, ok := em.clientEntities[entityOptions.ClientID]
771 if !ok {
772 return newEntityNotFoundError("client", entityOptions.ClientID)
773 }
774
775 sessionOpts := options.Session()
776 if entityOptions.SessionOptions != nil {
777 sessionOpts = entityOptions.SessionOptions.SessionOptions
778 }
779
780 sess, err := client.StartSession(sessionOpts)
781 if err != nil {
782 return fmt.Errorf("error starting session: %w", err)
783 }
784
785 em.sessions[entityOptions.ID] = sess
786 return nil
787 }
788
789 func (em *EntityMap) addGridFSBucketEntity(entityOptions *entityOptions) error {
790 db, ok := em.dbEntites[entityOptions.DatabaseID]
791 if !ok {
792 return newEntityNotFoundError("database", entityOptions.DatabaseID)
793 }
794
795 bucketOpts := options.GridFSBucket()
796 if entityOptions.GridFSBucketOptions != nil {
797 bucketOpts = entityOptions.GridFSBucketOptions.BucketOptions
798 }
799
800 bucket, err := gridfs.NewBucket(db, bucketOpts)
801 if err != nil {
802 return fmt.Errorf("error creating GridFS bucket: %w", err)
803 }
804
805 em.gridfsBuckets[entityOptions.ID] = bucket
806 return nil
807 }
808
809 func (em *EntityMap) verifyEntityDoesNotExist(id string) error {
810 if _, ok := em.allEntities[id]; ok {
811 return fmt.Errorf("entity with ID %q already exists", id)
812 }
813 return nil
814 }
815
// newEntityNotFoundError builds the error returned when no entity of kind
// entityType is registered under entityID. The ID is quoted with %q.
func newEntityNotFoundError(entityType, entityID string) error {
	const format = "no %s entity found with ID %q"
	return fmt.Errorf(format, entityType, entityID)
}
819
View as plain text