...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/crypt.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package driver
     8  
import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"time"

	"go.mongodb.org/mongo-driver/bson/bsontype"
	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
	"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
)
    23  
    24  const (
    25  	defaultKmsPort    = 443
    26  	defaultKmsTimeout = 10 * time.Second
    27  )
    28  
    29  // CollectionInfoFn is a callback used to retrieve collection information.
    30  type CollectionInfoFn func(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error)
    31  
    32  // KeyRetrieverFn is a callback used to retrieve keys from the key vault.
    33  type KeyRetrieverFn func(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error)
    34  
    35  // MarkCommandFn is a callback used to add encryption markings to a command.
    36  type MarkCommandFn func(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
    37  
// CryptOptions specifies options to configure a Crypt instance.
//
// NewCrypt reads the following fields:
//   - MongoCrypt: the mongocrypt handle used to create encryption, decryption, data-key and rewrap contexts and
//     to fetch KMS provider information; Close on the resulting Crypt closes it.
//   - CollInfoFn: invoked when a context needs collection information.
//   - KeyFn: invoked when a context needs documents from the key vault.
//   - MarkFn: invoked when a context needs a command with encryption markings.
//   - TLSConfig: TLS configuration keyed by KMS provider name. Providers without an entry are contacted with a
//     default configuration requiring TLS 1.2 or newer. The map is stored by reference, not copied.
//   - BypassAutoEncryption: when true, Encrypt returns commands unchanged.
//
// BypassQueryAnalysis is not read by NewCrypt. NOTE(review): presumably it is consumed where MongoCrypt is
// configured — confirm against the caller.
type CryptOptions struct {
	MongoCrypt           *mongocrypt.MongoCrypt
	CollInfoFn           CollectionInfoFn
	KeyFn                KeyRetrieverFn
	MarkFn               MarkCommandFn
	TLSConfig            map[string]*tls.Config
	BypassAutoEncryption bool
	BypassQueryAnalysis  bool
}
    48  
// Crypt is an interface implemented by types that can encrypt and decrypt instances of
// bsoncore.Document.
//
// Users should rely on the driver's crypt type (used by default) for encryption and decryption
// unless they are perfectly confident in another implementation of Crypt.
type Crypt interface {
	// Encrypt encrypts the given command, which targets database db.
	Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
	// Decrypt decrypts the given command response.
	Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error)
	// CreateDataKey creates a data key using the given KMS provider and options and returns the resulting
	// document.
	CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error)
	// EncryptExplicit encrypts the given value with the given options. It returns the BSON binary subtype and
	// payload of the encrypted value.
	EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error)
	// EncryptExplicitExpression encrypts the given expression with the given options and returns the encrypted
	// expression document.
	EncryptExplicitExpression(ctx context.Context, val bsoncore.Document, opts *options.ExplicitEncryptionOptions) (bsoncore.Document, error)
	// DecryptExplicit decrypts an encrypted value given as a BSON binary subtype and payload (the shape returned
	// by EncryptExplicit) and returns the plaintext value.
	DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error)
	// Close cleans up any resources associated with the Crypt instance.
	Close()
	// BypassAutoEncryption returns true if auto-encryption should be bypassed.
	BypassAutoEncryption() bool
	// RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents
	// to be returned as a slice of bsoncore.Document.
	// NOTE(review): filter is passed through as raw bytes; presumably an encoded BSON filter document — confirm.
	RewrapDataKey(ctx context.Context, filter []byte, opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error)
}
    75  
// crypt consumes the mongocrypt.MongoCrypt type to iterate the mongocrypt state machine and perform encryption
// and decryption.
//
// While the state machine runs, a mongocrypt context may need data from MongoDB. collInfoFn supplies collection
// information, keyFn supplies key-vault documents and markFn supplies commands with encryption markings.
// tlsConfig maps a KMS provider name to the TLS configuration used when contacting that provider's KMS host.
// bypassAutoEncryption makes Encrypt return commands unchanged.
type crypt struct {
	mongoCrypt *mongocrypt.MongoCrypt
	collInfoFn CollectionInfoFn
	keyFn      KeyRetrieverFn
	markFn     MarkCommandFn
	tlsConfig  map[string]*tls.Config

	bypassAutoEncryption bool
}
    87  
    88  // NewCrypt creates a new Crypt instance configured with the given AutoEncryptionOptions.
    89  func NewCrypt(opts *CryptOptions) Crypt {
    90  	c := &crypt{
    91  		mongoCrypt:           opts.MongoCrypt,
    92  		collInfoFn:           opts.CollInfoFn,
    93  		keyFn:                opts.KeyFn,
    94  		markFn:               opts.MarkFn,
    95  		tlsConfig:            opts.TLSConfig,
    96  		bypassAutoEncryption: opts.BypassAutoEncryption,
    97  	}
    98  	return c
    99  }
   100  
   101  // Encrypt encrypts the given command.
   102  func (c *crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error) {
   103  	if c.bypassAutoEncryption {
   104  		return cmd, nil
   105  	}
   106  
   107  	cryptCtx, err := c.mongoCrypt.CreateEncryptionContext(db, cmd)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	defer cryptCtx.Close()
   112  
   113  	return c.executeStateMachine(ctx, cryptCtx, db)
   114  }
   115  
   116  // Decrypt decrypts the given command response.
   117  func (c *crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
   118  	cryptCtx, err := c.mongoCrypt.CreateDecryptionContext(cmdResponse)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	defer cryptCtx.Close()
   123  
   124  	return c.executeStateMachine(ctx, cryptCtx, "")
   125  }
   126  
   127  // CreateDataKey creates a data key using the given KMS provider and options.
   128  func (c *crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error) {
   129  	cryptCtx, err := c.mongoCrypt.CreateDataKeyContext(kmsProvider, opts)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	defer cryptCtx.Close()
   134  
   135  	return c.executeStateMachine(ctx, cryptCtx, "")
   136  }
   137  
   138  // RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents to
   139  // be returned as a slice of bsoncore.Document.
   140  func (c *crypt) RewrapDataKey(ctx context.Context, filter []byte,
   141  	opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error) {
   142  
   143  	cryptCtx, err := c.mongoCrypt.RewrapDataKeyContext(filter, opts)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	defer cryptCtx.Close()
   148  
   149  	rewrappedBSON, err := c.executeStateMachine(ctx, cryptCtx, "")
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  	if rewrappedBSON == nil {
   154  		return nil, nil
   155  	}
   156  
   157  	// mongocrypt_ctx_rewrap_many_datakey_init wraps the documents in a BSON of the form { "v": [(BSON document), ...] }
   158  	// where each BSON document in the slice is a document containing a rewrapped datakey.
   159  	rewrappedDocumentBytes, err := rewrappedBSON.LookupErr("v")
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	// Parse the resulting BSON as individual documents.
   165  	rewrappedDocsArray, ok := rewrappedDocumentBytes.ArrayOK()
   166  	if !ok {
   167  		return nil, fmt.Errorf("expected results from mongocrypt_ctx_rewrap_many_datakey_init to be an array")
   168  	}
   169  
   170  	rewrappedDocumentValues, err := rewrappedDocsArray.Values()
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  
   175  	rewrappedDocuments := []bsoncore.Document{}
   176  	for _, rewrappedDocumentValue := range rewrappedDocumentValues {
   177  		if rewrappedDocumentValue.Type != bsontype.EmbeddedDocument {
   178  			// If a value in the document's array returned by mongocrypt is anything other than an embedded document,
   179  			// then something is wrong and we should terminate the routine.
   180  			return nil, fmt.Errorf("expected value of type %q, got: %q",
   181  				bsontype.EmbeddedDocument.String(),
   182  				rewrappedDocumentValue.Type.String())
   183  		}
   184  		rewrappedDocuments = append(rewrappedDocuments, rewrappedDocumentValue.Document())
   185  	}
   186  	return rewrappedDocuments, nil
   187  }
   188  
   189  // EncryptExplicit encrypts the given value with the given options.
   190  func (c *crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error) {
   191  	idx, doc := bsoncore.AppendDocumentStart(nil)
   192  	doc = bsoncore.AppendValueElement(doc, "v", val)
   193  	doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
   194  
   195  	cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionContext(doc, opts)
   196  	if err != nil {
   197  		return 0, nil, err
   198  	}
   199  	defer cryptCtx.Close()
   200  
   201  	res, err := c.executeStateMachine(ctx, cryptCtx, "")
   202  	if err != nil {
   203  		return 0, nil, err
   204  	}
   205  
   206  	sub, data := res.Lookup("v").Binary()
   207  	return sub, data, nil
   208  }
   209  
   210  // EncryptExplicitExpression encrypts the given expression with the given options.
   211  func (c *crypt) EncryptExplicitExpression(ctx context.Context, expr bsoncore.Document, opts *options.ExplicitEncryptionOptions) (bsoncore.Document, error) {
   212  	idx, doc := bsoncore.AppendDocumentStart(nil)
   213  	doc = bsoncore.AppendDocumentElement(doc, "v", expr)
   214  	doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
   215  
   216  	cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionExpressionContext(doc, opts)
   217  	if err != nil {
   218  		return nil, err
   219  	}
   220  	defer cryptCtx.Close()
   221  
   222  	res, err := c.executeStateMachine(ctx, cryptCtx, "")
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  
   227  	encryptedExpr := res.Lookup("v").Document()
   228  	return encryptedExpr, nil
   229  }
   230  
   231  // DecryptExplicit decrypts the given encrypted value.
   232  func (c *crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error) {
   233  	idx, doc := bsoncore.AppendDocumentStart(nil)
   234  	doc = bsoncore.AppendBinaryElement(doc, "v", subtype, data)
   235  	doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
   236  
   237  	cryptCtx, err := c.mongoCrypt.CreateExplicitDecryptionContext(doc)
   238  	if err != nil {
   239  		return bsoncore.Value{}, err
   240  	}
   241  	defer cryptCtx.Close()
   242  
   243  	res, err := c.executeStateMachine(ctx, cryptCtx, "")
   244  	if err != nil {
   245  		return bsoncore.Value{}, err
   246  	}
   247  
   248  	return res.Lookup("v"), nil
   249  }
   250  
   251  // Close cleans up any resources associated with the Crypt instance.
   252  func (c *crypt) Close() {
   253  	c.mongoCrypt.Close()
   254  }
   255  
   256  func (c *crypt) BypassAutoEncryption() bool {
   257  	return c.bypassAutoEncryption
   258  }
   259  
   260  func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Context, db string) (bsoncore.Document, error) {
   261  	var err error
   262  	for {
   263  		state := cryptCtx.State()
   264  		switch state {
   265  		case mongocrypt.NeedMongoCollInfo:
   266  			err = c.collectionInfo(ctx, cryptCtx, db)
   267  		case mongocrypt.NeedMongoMarkings:
   268  			err = c.markCommand(ctx, cryptCtx, db)
   269  		case mongocrypt.NeedMongoKeys:
   270  			err = c.retrieveKeys(ctx, cryptCtx)
   271  		case mongocrypt.NeedKms:
   272  			err = c.decryptKeys(cryptCtx)
   273  		case mongocrypt.Ready:
   274  			return cryptCtx.Finish()
   275  		case mongocrypt.Done:
   276  			return nil, nil
   277  		case mongocrypt.NeedKmsCredentials:
   278  			err = c.provideKmsProviders(ctx, cryptCtx)
   279  		default:
   280  			return nil, fmt.Errorf("invalid Crypt state: %v", state)
   281  		}
   282  		if err != nil {
   283  			return nil, err
   284  		}
   285  	}
   286  }
   287  
   288  func (c *crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
   289  	op, err := cryptCtx.NextOperation()
   290  	if err != nil {
   291  		return err
   292  	}
   293  
   294  	collInfo, err := c.collInfoFn(ctx, db, op)
   295  	if err != nil {
   296  		return err
   297  	}
   298  	if collInfo != nil {
   299  		if err = cryptCtx.AddOperationResult(collInfo); err != nil {
   300  			return err
   301  		}
   302  	}
   303  
   304  	return cryptCtx.CompleteOperation()
   305  }
   306  
   307  func (c *crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
   308  	op, err := cryptCtx.NextOperation()
   309  	if err != nil {
   310  		return err
   311  	}
   312  
   313  	markedCmd, err := c.markFn(ctx, db, op)
   314  	if err != nil {
   315  		return err
   316  	}
   317  	if err = cryptCtx.AddOperationResult(markedCmd); err != nil {
   318  		return err
   319  	}
   320  
   321  	return cryptCtx.CompleteOperation()
   322  }
   323  
   324  func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
   325  	op, err := cryptCtx.NextOperation()
   326  	if err != nil {
   327  		return err
   328  	}
   329  
   330  	keys, err := c.keyFn(ctx, op)
   331  	if err != nil {
   332  		return err
   333  	}
   334  
   335  	for _, key := range keys {
   336  		if err = cryptCtx.AddOperationResult(key); err != nil {
   337  			return err
   338  		}
   339  	}
   340  
   341  	return cryptCtx.CompleteOperation()
   342  }
   343  
   344  func (c *crypt) decryptKeys(cryptCtx *mongocrypt.Context) error {
   345  	for {
   346  		kmsCtx := cryptCtx.NextKmsContext()
   347  		if kmsCtx == nil {
   348  			break
   349  		}
   350  
   351  		if err := c.decryptKey(kmsCtx); err != nil {
   352  			return err
   353  		}
   354  	}
   355  
   356  	return cryptCtx.FinishKmsContexts()
   357  }
   358  
   359  func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
   360  	host, err := kmsCtx.HostName()
   361  	if err != nil {
   362  		return err
   363  	}
   364  	msg, err := kmsCtx.Message()
   365  	if err != nil {
   366  		return err
   367  	}
   368  
   369  	// add a port to the address if it's not already present
   370  	addr := host
   371  	if idx := strings.IndexByte(host, ':'); idx == -1 {
   372  		addr = fmt.Sprintf("%s:%d", host, defaultKmsPort)
   373  	}
   374  
   375  	kmsProvider := kmsCtx.KMSProvider()
   376  	tlsCfg := c.tlsConfig[kmsProvider]
   377  	if tlsCfg == nil {
   378  		tlsCfg = &tls.Config{MinVersion: tls.VersionTLS12}
   379  	}
   380  	conn, err := tls.Dial("tcp", addr, tlsCfg)
   381  	if err != nil {
   382  		return err
   383  	}
   384  	defer func() {
   385  		_ = conn.Close()
   386  	}()
   387  
   388  	if err = conn.SetWriteDeadline(time.Now().Add(defaultKmsTimeout)); err != nil {
   389  		return err
   390  	}
   391  	if _, err = conn.Write(msg); err != nil {
   392  		return err
   393  	}
   394  
   395  	for {
   396  		bytesNeeded := kmsCtx.BytesNeeded()
   397  		if bytesNeeded == 0 {
   398  			return nil
   399  		}
   400  
   401  		res := make([]byte, bytesNeeded)
   402  		bytesRead, err := conn.Read(res)
   403  		if err != nil && !errors.Is(err, io.EOF) {
   404  			return err
   405  		}
   406  
   407  		if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil {
   408  			return err
   409  		}
   410  	}
   411  }
   412  
   413  func (c *crypt) provideKmsProviders(ctx context.Context, cryptCtx *mongocrypt.Context) error {
   414  	kmsProviders, err := c.mongoCrypt.GetKmsProviders(ctx)
   415  	if err != nil {
   416  		return err
   417  	}
   418  	return cryptCtx.ProvideKmsProviders(kmsProviders)
   419  }
   420  

View as plain text