...

Source file src/go.mongodb.org/mongo-driver/mongo/integration/unified/entity.go

Documentation: go.mongodb.org/mongo-driver/mongo/integration/unified

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package unified
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"errors"
    14  	"fmt"
    15  	"os"
    16  	"sync"
    17  	"sync/atomic"
    18  	"time"
    19  
    20  	"go.mongodb.org/mongo-driver/bson"
    21  	"go.mongodb.org/mongo-driver/mongo"
    22  	"go.mongodb.org/mongo-driver/mongo/gridfs"
    23  	"go.mongodb.org/mongo-driver/mongo/options"
    24  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    25  )
    26  
    27  var (
    28  	// ErrEntityMapOpen is returned when a slice entity is accessed while the EntityMap is open
    29  	ErrEntityMapOpen = errors.New("slices cannot be accessed while EntityMap is open")
    30  )
    31  
    32  var (
    33  	tlsCAFile                   = os.Getenv("CSFLE_TLS_CA_FILE")
    34  	tlsClientCertificateKeyFile = os.Getenv("CSFLE_TLS_CLIENT_CERT_FILE")
    35  )
    36  
// storeEventsAsEntitiesConfig is one entry of a client entity's storeEventsAsEntities option.
// addClientEntity registers EventListID as an (initially empty) event list entity before the client is
// created. Presumably events whose names appear in Events are appended to that list by the client entity's
// monitoring code — TODO confirm against newClientEntity.
type storeEventsAsEntitiesConfig struct {
	EventListID string   `bson:"id"`
	Events      []string `bson:"events"`
}
    41  
// observeLogMessages is decoded from the observeLogMessages field of an entity description. It holds one
// string setting per log component (command, topology, serverSelection, connection). NOTE(review): the values
// are presumably log severity level names — confirm against the code that configures the logger.
type observeLogMessages struct {
	Command         string `bson:"command"`
	Topology        string `bson:"topology"`
	ServerSelection string `bson:"serverSelection"`
	Connection      string `bson:"connection"`
}
    48  
// entityOptions represents all options that can be used to configure an entity. Because there are multiple entity
// types, only a subset of the options that this type contains apply to any given entity.
//
// Values are decoded from an entity description through the bson tags below. Fields that do not apply to the entity
// being created keep their zero values and are ignored by the corresponding add*Entity helper.
type entityOptions struct {
	// Options that apply to all entity types. ID must be unique across every entity in the EntityMap.
	ID string `bson:"id"`

	// Options for client entities.
	URIOptions               bson.M                        `bson:"uriOptions"`
	UseMultipleMongoses      *bool                         `bson:"useMultipleMongoses"`
	ObserveEvents            []string                      `bson:"observeEvents"`
	IgnoredCommands          []string                      `bson:"ignoreCommandMonitoringEvents"`
	ObserveSensitiveCommands *bool                         `bson:"observeSensitiveCommands"`
	StoreEventsAsEntities    []storeEventsAsEntitiesConfig `bson:"storeEventsAsEntities"`
	ServerAPIOptions         *serverAPIOptions             `bson:"serverApi"`

	// Options for logger entities.
	ObserveLogMessages *observeLogMessages `bson:"observeLogMessages"`

	// Options for database entities. A nil DatabaseOptions means default database options.
	DatabaseName    string                 `bson:"databaseName"`
	DatabaseOptions *dbOrCollectionOptions `bson:"databaseOptions"`

	// Options for collection entities. A nil CollectionOptions means default collection options.
	CollectionName    string                 `bson:"collectionName"`
	CollectionOptions *dbOrCollectionOptions `bson:"collectionOptions"`

	// Options for session entities. A nil SessionOptions means default session options.
	SessionOptions *sessionOptions `bson:"sessionOptions"`

	// Options for GridFS bucket entities. A nil GridFSBucketOptions means default bucket options.
	GridFSBucketOptions *gridFSBucketOptions `bson:"bucketOptions"`

	// Options that reference other entities. ClientID is read by database and session entities; DatabaseID is
	// read by collection and GridFS bucket entities. The referenced entity must already exist.
	ClientID   string `bson:"client"`
	DatabaseID string `bson:"database"`

	// Options for clientEncryption entities.
	ClientEncryptionOpts *clientEncryptionOpts `bson:"clientEncryptionOpts"`
}
    87  
    88  func (eo *entityOptions) setHeartbeatFrequencyMS(freq time.Duration) {
    89  	if eo.URIOptions == nil {
    90  		eo.URIOptions = make(bson.M)
    91  	}
    92  
    93  	if _, ok := eo.URIOptions["heartbeatFrequencyMS"]; !ok {
    94  		// The UST values for heartbeatFrequencyMS are given as int32,
    95  		// so we need to cast the frequency as int32 before setting it
    96  		// on the URIOptions map.
    97  		eo.URIOptions["heartbeatFrequencyMS"] = int32(freq.Milliseconds())
    98  	}
    99  }
   100  
   101  // newCollectionEntityOptions constructs an entity options object for a
   102  // collection.
   103  func newCollectionEntityOptions(id string, databaseID string, collectionName string,
   104  	opts *dbOrCollectionOptions) *entityOptions {
   105  	options := &entityOptions{
   106  		ID:                id,
   107  		DatabaseID:        databaseID,
   108  		CollectionName:    collectionName,
   109  		CollectionOptions: opts,
   110  	}
   111  
   112  	return options
   113  }
   114  
   115  type task struct {
   116  	name    string
   117  	execute func() error
   118  }
   119  
   120  type backgroundRoutine struct {
   121  	tasks chan *task
   122  	wg    sync.WaitGroup
   123  	err   error
   124  }
   125  
   126  func (b *backgroundRoutine) start() {
   127  	b.wg.Add(1)
   128  
   129  	go func() {
   130  		defer b.wg.Done()
   131  
   132  		for t := range b.tasks {
   133  			if b.err != nil {
   134  				continue
   135  			}
   136  
   137  			ch := make(chan error)
   138  			go func(task *task) {
   139  				ch <- task.execute()
   140  			}(t)
   141  			select {
   142  			case err := <-ch:
   143  				if err != nil {
   144  					b.err = fmt.Errorf("error running operation %s: %v", t.name, err)
   145  				}
   146  			case <-time.After(10 * time.Second):
   147  				b.err = fmt.Errorf("timed out after 10 seconds")
   148  			}
   149  		}
   150  	}()
   151  }
   152  
   153  func (b *backgroundRoutine) stop() error {
   154  	close(b.tasks)
   155  	b.wg.Wait()
   156  	return b.err
   157  }
   158  
   159  func (b *backgroundRoutine) addTask(name string, execute func() error) bool {
   160  	select {
   161  	case b.tasks <- &task{
   162  		name:    name,
   163  		execute: execute,
   164  	}:
   165  		return true
   166  	default:
   167  		return false
   168  	}
   169  }
   170  
   171  func newBackgroundRoutine() *backgroundRoutine {
   172  	routine := &backgroundRoutine{
   173  		tasks: make(chan *task, 10),
   174  	}
   175  
   176  	return routine
   177  }
   178  
// clientEncryptionOpts holds the clientEncryptionOpts of a clientEncryption entity.
//
// KeyVaultClient is the ID of an existing client entity. That client is passed to mongo.NewClientEncryption,
// and the resulting ClientEncryption disconnects it on Close.
//
// KeyVaultNamespace is forwarded to SetKeyVaultNamespace.
//
// KmsProviders maps a KMS provider name ("aws", "azure", "gcp", "kmip", "local") to its raw credential document.
// addClientEncryptionEntity resolves each credential in that document through getKmsCredential.
type clientEncryptionOpts struct {
	KeyVaultClient    string              `bson:"keyVaultClient"`
	KeyVaultNamespace string              `bson:"keyVaultNamespace"`
	KmsProviders      map[string]bson.Raw `bson:"kmsProviders"`
}
   184  
// EntityMap is used to store entities during tests. This type enforces uniqueness so no two entities can have the same
// ID, even if they are of different types. It also enforces referential integrity so construction of an entity that
// references another (e.g. a database entity references a client) will fail if the referenced entity does not exist.
// Accessors are available for the BSON entities.
//
// allEntities is the set of every registered ID and is what uniqueness is checked against. Each remaining map holds the
// entities of a single kind, keyed by ID. Appends to event lists are serialized by evtLock. The event list and BSON
// array slices may only be read through EventList/BSONArray once close has set the closed flag. Thread entities live
// in routinesMap as *backgroundRoutine values.
type EntityMap struct {
	allEntities              map[string]struct{}
	cursorEntities           map[string]cursor
	clientEntities           map[string]*clientEntity
	dbEntites                map[string]*mongo.Database
	collEntities             map[string]*mongo.Collection
	sessions                 map[string]mongo.Session
	gridfsBuckets            map[string]*gridfs.Bucket
	bsonValues               map[string]bson.RawValue
	eventListEntities        map[string][]bson.Raw
	bsonArrayEntities        map[string][]bson.Raw // for storing errors and failures from a loop operation
	successValues            map[string]int32
	iterationValues          map[string]int32
	clientEncryptionEntities map[string]*mongo.ClientEncryption
	routinesMap              sync.Map // maps thread name to *backgroundRoutine
	evtLock                  sync.Mutex
	closed                   atomic.Value
	// keyVaultClientIDs tracks IDs of clients used as a keyVaultClient in ClientEncryption objects.
	// ClientEncryption.Close() calls Disconnect on the keyVaultClient.
	// EntityMap.close() must skip calling Disconnect on any client entity referenced in keyVaultClientIDs.
	keyVaultClientIDs map[string]bool
}
   211  
   212  func (em *EntityMap) isClosed() bool {
   213  	return em.closed.Load().(bool)
   214  }
   215  
   216  func (em *EntityMap) setClosed(val bool) {
   217  	em.closed.Store(val)
   218  }
   219  
   220  func newEntityMap() *EntityMap {
   221  	em := &EntityMap{
   222  		allEntities:              make(map[string]struct{}),
   223  		gridfsBuckets:            make(map[string]*gridfs.Bucket),
   224  		bsonValues:               make(map[string]bson.RawValue),
   225  		cursorEntities:           make(map[string]cursor),
   226  		clientEntities:           make(map[string]*clientEntity),
   227  		collEntities:             make(map[string]*mongo.Collection),
   228  		dbEntites:                make(map[string]*mongo.Database),
   229  		sessions:                 make(map[string]mongo.Session),
   230  		eventListEntities:        make(map[string][]bson.Raw),
   231  		bsonArrayEntities:        make(map[string][]bson.Raw),
   232  		successValues:            make(map[string]int32),
   233  		iterationValues:          make(map[string]int32),
   234  		clientEncryptionEntities: make(map[string]*mongo.ClientEncryption),
   235  		keyVaultClientIDs:        make(map[string]bool),
   236  	}
   237  	em.setClosed(false)
   238  	return em
   239  }
   240  
   241  func (em *EntityMap) addBSONEntity(id string, val bson.RawValue) error {
   242  	if err := em.verifyEntityDoesNotExist(id); err != nil {
   243  		return err
   244  	}
   245  
   246  	em.allEntities[id] = struct{}{}
   247  	em.bsonValues[id] = val
   248  	return nil
   249  }
   250  
   251  func (em *EntityMap) addCursorEntity(id string, cursor cursor) error {
   252  	if err := em.verifyEntityDoesNotExist(id); err != nil {
   253  		return err
   254  	}
   255  
   256  	em.allEntities[id] = struct{}{}
   257  	em.cursorEntities[id] = cursor
   258  	return nil
   259  }
   260  
   261  func (em *EntityMap) addBSONArrayEntity(id string) error {
   262  	// Error if a non-BSON array entity exists with the same name
   263  	if _, ok := em.allEntities[id]; ok {
   264  		if _, ok := em.bsonArrayEntities[id]; !ok {
   265  			return fmt.Errorf("non-BSON array entity with ID %q already exists", id)
   266  		}
   267  		return nil
   268  	}
   269  
   270  	em.allEntities[id] = struct{}{}
   271  	em.bsonArrayEntities[id] = []bson.Raw{}
   272  	return nil
   273  }
   274  
   275  func (em *EntityMap) addSuccessesEntity(id string) error {
   276  	if err := em.verifyEntityDoesNotExist(id); err != nil {
   277  		return err
   278  	}
   279  
   280  	em.allEntities[id] = struct{}{}
   281  	em.successValues[id] = 0
   282  	return nil
   283  }
   284  
   285  func (em *EntityMap) addIterationsEntity(id string) error {
   286  	if err := em.verifyEntityDoesNotExist(id); err != nil {
   287  		return err
   288  	}
   289  
   290  	em.allEntities[id] = struct{}{}
   291  	em.iterationValues[id] = 0
   292  	return nil
   293  }
   294  
   295  func (em *EntityMap) addEventsEntity(id string) error {
   296  	if err := em.verifyEntityDoesNotExist(id); err != nil {
   297  		return err
   298  	}
   299  	em.allEntities[id] = struct{}{}
   300  	em.eventListEntities[id] = []bson.Raw{}
   301  	return nil
   302  }
   303  
   304  func (em *EntityMap) incrementSuccesses(id string) error {
   305  	if _, ok := em.successValues[id]; !ok {
   306  		return newEntityNotFoundError("successes", id)
   307  	}
   308  	em.successValues[id]++
   309  	return nil
   310  }
   311  
   312  func (em *EntityMap) incrementIterations(id string) error {
   313  	if _, ok := em.iterationValues[id]; !ok {
   314  		return newEntityNotFoundError("iterations", id)
   315  	}
   316  	em.iterationValues[id]++
   317  	return nil
   318  }
   319  
   320  func (em *EntityMap) appendEventsEntity(id string, doc bson.Raw) {
   321  	em.evtLock.Lock()
   322  	defer em.evtLock.Unlock()
   323  	if _, ok := em.eventListEntities[id]; ok {
   324  		em.eventListEntities[id] = append(em.eventListEntities[id], doc)
   325  	}
   326  }
   327  
   328  func (em *EntityMap) appendBSONArrayEntity(id string, doc bson.Raw) error {
   329  	if _, ok := em.bsonArrayEntities[id]; !ok {
   330  		return newEntityNotFoundError("BSON array", id)
   331  	}
   332  	em.bsonArrayEntities[id] = append(em.bsonArrayEntities[id], doc)
   333  	return nil
   334  }
   335  
   336  func (em *EntityMap) addEntity(ctx context.Context, entityType string, entityOptions *entityOptions) error {
   337  	if err := em.verifyEntityDoesNotExist(entityOptions.ID); err != nil {
   338  		return err
   339  	}
   340  
   341  	var err error
   342  	switch entityType {
   343  	case "client":
   344  		err = em.addClientEntity(ctx, entityOptions)
   345  	case "database":
   346  		err = em.addDatabaseEntity(entityOptions)
   347  	case "collection":
   348  		err = em.addCollectionEntity(entityOptions)
   349  	case "session":
   350  		err = em.addSessionEntity(entityOptions)
   351  	case "thread":
   352  		routine := newBackgroundRoutine()
   353  		em.routinesMap.Store(entityOptions.ID, routine)
   354  		routine.start()
   355  	case "bucket":
   356  		err = em.addGridFSBucketEntity(entityOptions)
   357  	case "clientEncryption":
   358  		err = em.addClientEncryptionEntity(entityOptions)
   359  	default:
   360  		return fmt.Errorf("unrecognized entity type %q", entityType)
   361  	}
   362  
   363  	if err != nil {
   364  		return fmt.Errorf("error constructing entity of type %q: %w", entityType, err)
   365  	}
   366  	em.allEntities[entityOptions.ID] = struct{}{}
   367  	return nil
   368  }
   369  
   370  func (em *EntityMap) gridFSBucket(id string) (*gridfs.Bucket, error) {
   371  	bucket, ok := em.gridfsBuckets[id]
   372  	if !ok {
   373  		return nil, newEntityNotFoundError("gridfs bucket", id)
   374  	}
   375  	return bucket, nil
   376  }
   377  
   378  func (em *EntityMap) cursor(id string) (cursor, error) {
   379  	cursor, ok := em.cursorEntities[id]
   380  	if !ok {
   381  		return nil, newEntityNotFoundError("cursor", id)
   382  	}
   383  	return cursor, nil
   384  }
   385  
   386  func (em *EntityMap) client(id string) (*clientEntity, error) {
   387  	client, ok := em.clientEntities[id]
   388  	if !ok {
   389  		return nil, newEntityNotFoundError("client", id)
   390  	}
   391  	return client, nil
   392  }
   393  
   394  func (em *EntityMap) clientEncryption(id string) (*mongo.ClientEncryption, error) {
   395  	cee, ok := em.clientEncryptionEntities[id]
   396  	if !ok {
   397  		return nil, newEntityNotFoundError("client", id)
   398  	}
   399  	return cee, nil
   400  }
   401  
   402  func (em *EntityMap) clients() map[string]*clientEntity {
   403  	return em.clientEntities
   404  }
   405  
   406  func (em *EntityMap) collections() map[string]*mongo.Collection {
   407  	return em.collEntities
   408  }
   409  
   410  func (em *EntityMap) collection(id string) (*mongo.Collection, error) {
   411  	coll, ok := em.collEntities[id]
   412  	if !ok {
   413  		return nil, newEntityNotFoundError("collection", id)
   414  	}
   415  	return coll, nil
   416  }
   417  
   418  func (em *EntityMap) database(id string) (*mongo.Database, error) {
   419  	db, ok := em.dbEntites[id]
   420  	if !ok {
   421  		return nil, newEntityNotFoundError("database", id)
   422  	}
   423  	return db, nil
   424  }
   425  
   426  func (em *EntityMap) session(id string) (mongo.Session, error) {
   427  	sess, ok := em.sessions[id]
   428  	if !ok {
   429  		return nil, newEntityNotFoundError("session", id)
   430  	}
   431  	return sess, nil
   432  }
   433  
   434  // BSONValue returns the bson.RawValue associated with id
   435  func (em *EntityMap) BSONValue(id string) (bson.RawValue, error) {
   436  	val, ok := em.bsonValues[id]
   437  	if !ok {
   438  		return emptyRawValue, newEntityNotFoundError("BSON", id)
   439  	}
   440  	return val, nil
   441  }
   442  
   443  // EventList returns the array of event documents associated with id. This should only be accessed
   444  // after the test is finished running
   445  func (em *EntityMap) EventList(id string) ([]bson.Raw, error) {
   446  	if !em.isClosed() {
   447  		return nil, ErrEntityMapOpen
   448  	}
   449  	val, ok := em.eventListEntities[id]
   450  	if !ok {
   451  		return nil, newEntityNotFoundError("event list", id)
   452  	}
   453  	return val, nil
   454  }
   455  
   456  // BSONArray returns the BSON document array associated with id. This should only be accessed
   457  // after the test is finished running
   458  func (em *EntityMap) BSONArray(id string) ([]bson.Raw, error) {
   459  	if !em.isClosed() {
   460  		return nil, ErrEntityMapOpen
   461  	}
   462  	val, ok := em.bsonArrayEntities[id]
   463  	if !ok {
   464  		return nil, newEntityNotFoundError("BSON array", id)
   465  	}
   466  	return val, nil
   467  }
   468  
   469  // Successes returns the number of successes associated with id
   470  func (em *EntityMap) Successes(id string) (int32, error) {
   471  	val, ok := em.successValues[id]
   472  	if !ok {
   473  		return 0, newEntityNotFoundError("successes", id)
   474  	}
   475  	return val, nil
   476  }
   477  
   478  // Iterations returns the number of iterations associated with id
   479  func (em *EntityMap) Iterations(id string) (int32, error) {
   480  	val, ok := em.iterationValues[id]
   481  	if !ok {
   482  		return 0, newEntityNotFoundError("iterations", id)
   483  	}
   484  	return val, nil
   485  }
   486  
   487  // close disposes of the session and client entities associated with this map.
   488  func (em *EntityMap) close(ctx context.Context) []error {
   489  	for _, sess := range em.sessions {
   490  		sess.EndSession(ctx)
   491  	}
   492  
   493  	var errs []error
   494  	for id, cursor := range em.cursorEntities {
   495  		if err := cursor.Close(ctx); err != nil {
   496  			errs = append(errs, fmt.Errorf("error closing cursor with ID %q: %w", id, err))
   497  		}
   498  	}
   499  
   500  	for id, client := range em.clientEntities {
   501  		if ok := em.keyVaultClientIDs[id]; ok {
   502  			// Client will be closed in clientEncryption.Close()
   503  			continue
   504  		}
   505  
   506  		if err := client.disconnect(ctx); err != nil {
   507  			errs = append(errs, fmt.Errorf("error closing client with ID %q: %w", id, err))
   508  		}
   509  	}
   510  
   511  	for id, clientEncryption := range em.clientEncryptionEntities {
   512  		if err := clientEncryption.Close(ctx); err != nil {
   513  			errs = append(errs, fmt.Errorf("error closing clientEncryption with ID: %q: %w", id, err))
   514  		}
   515  	}
   516  
   517  	em.setClosed(true)
   518  	return errs
   519  }
   520  
   521  func (em *EntityMap) addClientEntity(ctx context.Context, entityOptions *entityOptions) error {
   522  	var client *clientEntity
   523  
   524  	for _, eventsAsEntity := range entityOptions.StoreEventsAsEntities {
   525  		if entityOptions.ID == eventsAsEntity.EventListID {
   526  			return fmt.Errorf("entity with ID %q already exists", entityOptions.ID)
   527  		}
   528  		if err := em.addEventsEntity(eventsAsEntity.EventListID); err != nil {
   529  			return err
   530  		}
   531  	}
   532  
   533  	client, err := newClientEntity(ctx, em, entityOptions)
   534  	if err != nil {
   535  		return fmt.Errorf("error creating client entity: %w", err)
   536  	}
   537  
   538  	em.clientEntities[entityOptions.ID] = client
   539  	return nil
   540  }
   541  
   542  func (em *EntityMap) addDatabaseEntity(entityOptions *entityOptions) error {
   543  	client, ok := em.clientEntities[entityOptions.ClientID]
   544  	if !ok {
   545  		return newEntityNotFoundError("client", entityOptions.ClientID)
   546  	}
   547  
   548  	dbOpts := options.Database()
   549  	if entityOptions.DatabaseOptions != nil {
   550  		dbOpts = entityOptions.DatabaseOptions.DBOptions
   551  	}
   552  
   553  	em.dbEntites[entityOptions.ID] = client.Database(entityOptions.DatabaseName, dbOpts)
   554  	return nil
   555  }
   556  
   557  // getKmsCredential processes a value of an input KMS provider credential.
   558  // An empty document returns from the environment.
   559  // A string is returned as-is.
   560  func getKmsCredential(kmsDocument bson.Raw, credentialName string, envVar string, defaultValue string) (string, error) {
   561  	credentialVal, err := kmsDocument.LookupErr(credentialName)
   562  	if errors.Is(err, bsoncore.ErrElementNotFound) {
   563  		return "", nil
   564  	}
   565  	if err != nil {
   566  		return "", err
   567  	}
   568  
   569  	if str, ok := credentialVal.StringValueOK(); ok {
   570  		return str, nil
   571  	}
   572  
   573  	var ok bool
   574  	var doc bson.Raw
   575  	if doc, ok = credentialVal.DocumentOK(); !ok {
   576  		return "", fmt.Errorf("expected String or Document for %v, got: %v", credentialName, credentialVal)
   577  	}
   578  
   579  	placeholderDoc := bsoncore.NewDocumentBuilder().AppendInt32("$$placeholder", 1).Build()
   580  
   581  	// Check if document is a placeholder.
   582  	if !bytes.Equal(doc, placeholderDoc) {
   583  		return "", fmt.Errorf("unexpected non-empty document for %v: %v", credentialName, doc)
   584  	}
   585  	if envVar == "" {
   586  		return defaultValue, nil
   587  	}
   588  	if os.Getenv(envVar) == "" {
   589  		if defaultValue != "" {
   590  			return defaultValue, nil
   591  		}
   592  		return "", fmt.Errorf("unable to get environment value for %v. Please set the CSFLE environment variable: %v", credentialName, envVar)
   593  	}
   594  	return os.Getenv(envVar), nil
   595  
   596  }
   597  
   598  func (em *EntityMap) addClientEncryptionEntity(entityOptions *entityOptions) error {
   599  	// Construct KMS providers.
   600  	kmsProviders := make(map[string]map[string]interface{})
   601  	ceo := entityOptions.ClientEncryptionOpts
   602  	tlsconf := make(map[string]*tls.Config)
   603  	if aws, ok := ceo.KmsProviders["aws"]; ok {
   604  		kmsProviders["aws"] = make(map[string]interface{})
   605  
   606  		awsSessionToken, err := getKmsCredential(aws, "sessionToken", "CSFLE_AWS_TEMP_SESSION_TOKEN", "")
   607  		if err != nil {
   608  			return err
   609  		}
   610  		if awsSessionToken != "" {
   611  			// Get temporary AWS credentials.
   612  			kmsProviders["aws"]["sessionToken"] = awsSessionToken
   613  			awsAccessKeyID, err := getKmsCredential(aws, "accessKeyId", "CSFLE_AWS_TEMP_ACCESS_KEY_ID", "")
   614  			if err != nil {
   615  				return err
   616  			}
   617  			if awsAccessKeyID != "" {
   618  				kmsProviders["aws"]["accessKeyId"] = awsAccessKeyID
   619  			}
   620  
   621  			awsSecretAccessKey, err := getKmsCredential(aws, "secretAccessKey", "CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", "")
   622  			if err != nil {
   623  				return err
   624  			}
   625  			if awsSecretAccessKey != "" {
   626  				kmsProviders["aws"]["secretAccessKey"] = awsSecretAccessKey
   627  			}
   628  		} else {
   629  			awsAccessKeyID, err := getKmsCredential(aws, "accessKeyId", "FLE_AWS_KEY", "")
   630  			if err != nil {
   631  				return err
   632  			}
   633  			if awsAccessKeyID != "" {
   634  				kmsProviders["aws"]["accessKeyId"] = awsAccessKeyID
   635  			}
   636  
   637  			awsSecretAccessKey, err := getKmsCredential(aws, "secretAccessKey", "FLE_AWS_SECRET", "")
   638  			if err != nil {
   639  				return err
   640  			}
   641  			if awsSecretAccessKey != "" {
   642  				kmsProviders["aws"]["secretAccessKey"] = awsSecretAccessKey
   643  			}
   644  		}
   645  
   646  	}
   647  
   648  	if azure, ok := ceo.KmsProviders["azure"]; ok {
   649  		kmsProviders["azure"] = make(map[string]interface{})
   650  
   651  		azureTenantID, err := getKmsCredential(azure, "tenantId", "FLE_AZURE_TENANTID", "")
   652  		if err != nil {
   653  			return err
   654  		}
   655  		if azureTenantID != "" {
   656  			kmsProviders["azure"]["tenantId"] = azureTenantID
   657  		}
   658  
   659  		azureClientID, err := getKmsCredential(azure, "clientId", "FLE_AZURE_CLIENTID", "")
   660  		if err != nil {
   661  			return err
   662  		}
   663  		if azureClientID != "" {
   664  			kmsProviders["azure"]["clientId"] = azureClientID
   665  		}
   666  
   667  		azureClientSecret, err := getKmsCredential(azure, "clientSecret", "FLE_AZURE_CLIENTSECRET", "")
   668  		if err != nil {
   669  			return err
   670  		}
   671  		if azureClientSecret != "" {
   672  			kmsProviders["azure"]["clientSecret"] = azureClientSecret
   673  		}
   674  	}
   675  
   676  	if gcp, ok := ceo.KmsProviders["gcp"]; ok {
   677  		kmsProviders["gcp"] = make(map[string]interface{})
   678  
   679  		gcpEmail, err := getKmsCredential(gcp, "email", "FLE_GCP_EMAIL", "")
   680  		if err != nil {
   681  			return err
   682  		}
   683  		if gcpEmail != "" {
   684  			kmsProviders["gcp"]["email"] = gcpEmail
   685  		}
   686  
   687  		gcpPrivateKey, err := getKmsCredential(gcp, "privateKey", "FLE_GCP_PRIVATEKEY", "")
   688  		if err != nil {
   689  			return err
   690  		}
   691  		if gcpPrivateKey != "" {
   692  			kmsProviders["gcp"]["privateKey"] = gcpPrivateKey
   693  		}
   694  	}
   695  
   696  	if kmip, ok := ceo.KmsProviders["kmip"]; ok {
   697  		kmsProviders["kmip"] = make(map[string]interface{})
   698  
   699  		kmipEndpoint, err := getKmsCredential(kmip, "endpoint", "", "localhost:5698")
   700  		if err != nil {
   701  			return err
   702  		}
   703  
   704  		if tlsClientCertificateKeyFile != "" && tlsCAFile != "" {
   705  			cfg, err := options.BuildTLSConfig(map[string]interface{}{
   706  				"tlsCertificateKeyFile": tlsClientCertificateKeyFile,
   707  				"tlsCAFile":             tlsCAFile,
   708  			})
   709  			if err != nil {
   710  				return fmt.Errorf("error constructing tls config: %w", err)
   711  			}
   712  			tlsconf["kmip"] = cfg
   713  		}
   714  
   715  		if kmipEndpoint != "" {
   716  			kmsProviders["kmip"]["endpoint"] = kmipEndpoint
   717  		}
   718  	}
   719  
   720  	if local, ok := ceo.KmsProviders["local"]; ok {
   721  		kmsProviders["local"] = make(map[string]interface{})
   722  
   723  		defaultLocalKeyBase64 := "Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
   724  		localKey, err := getKmsCredential(local, "key", "", defaultLocalKeyBase64)
   725  		if err != nil {
   726  			return err
   727  		}
   728  		if localKey != "" {
   729  			kmsProviders["local"]["key"] = localKey
   730  		}
   731  	}
   732  
   733  	em.keyVaultClientIDs[ceo.KeyVaultClient] = true
   734  	keyVaultClient, ok := em.clientEntities[ceo.KeyVaultClient]
   735  	if !ok {
   736  		return newEntityNotFoundError("client", ceo.KeyVaultClient)
   737  	}
   738  
   739  	ce, err := mongo.NewClientEncryption(
   740  		keyVaultClient.Client,
   741  		options.ClientEncryption().
   742  			SetKeyVaultNamespace(ceo.KeyVaultNamespace).
   743  			SetTLSConfig(tlsconf).
   744  			SetKmsProviders(kmsProviders))
   745  	if err != nil {
   746  		return err
   747  	}
   748  
   749  	em.clientEncryptionEntities[entityOptions.ID] = ce
   750  
   751  	return nil
   752  }
   753  
   754  func (em *EntityMap) addCollectionEntity(entityOptions *entityOptions) error {
   755  	db, ok := em.dbEntites[entityOptions.DatabaseID]
   756  	if !ok {
   757  		return newEntityNotFoundError("database", entityOptions.DatabaseID)
   758  	}
   759  
   760  	collOpts := options.Collection()
   761  	if entityOptions.CollectionOptions != nil {
   762  		collOpts = entityOptions.CollectionOptions.CollectionOptions
   763  	}
   764  
   765  	em.collEntities[entityOptions.ID] = db.Collection(entityOptions.CollectionName, collOpts)
   766  	return nil
   767  }
   768  
   769  func (em *EntityMap) addSessionEntity(entityOptions *entityOptions) error {
   770  	client, ok := em.clientEntities[entityOptions.ClientID]
   771  	if !ok {
   772  		return newEntityNotFoundError("client", entityOptions.ClientID)
   773  	}
   774  
   775  	sessionOpts := options.Session()
   776  	if entityOptions.SessionOptions != nil {
   777  		sessionOpts = entityOptions.SessionOptions.SessionOptions
   778  	}
   779  
   780  	sess, err := client.StartSession(sessionOpts)
   781  	if err != nil {
   782  		return fmt.Errorf("error starting session: %w", err)
   783  	}
   784  
   785  	em.sessions[entityOptions.ID] = sess
   786  	return nil
   787  }
   788  
   789  func (em *EntityMap) addGridFSBucketEntity(entityOptions *entityOptions) error {
   790  	db, ok := em.dbEntites[entityOptions.DatabaseID]
   791  	if !ok {
   792  		return newEntityNotFoundError("database", entityOptions.DatabaseID)
   793  	}
   794  
   795  	bucketOpts := options.GridFSBucket()
   796  	if entityOptions.GridFSBucketOptions != nil {
   797  		bucketOpts = entityOptions.GridFSBucketOptions.BucketOptions
   798  	}
   799  
   800  	bucket, err := gridfs.NewBucket(db, bucketOpts)
   801  	if err != nil {
   802  		return fmt.Errorf("error creating GridFS bucket: %w", err)
   803  	}
   804  
   805  	em.gridfsBuckets[entityOptions.ID] = bucket
   806  	return nil
   807  }
   808  
   809  func (em *EntityMap) verifyEntityDoesNotExist(id string) error {
   810  	if _, ok := em.allEntities[id]; ok {
   811  		return fmt.Errorf("entity with ID %q already exists", id)
   812  	}
   813  	return nil
   814  }
   815  
   816  func newEntityNotFoundError(entityType, entityID string) error {
   817  	return fmt.Errorf("no %s entity found with ID %q", entityType, entityID)
   818  }
   819  

View as plain text