1
2
3
4
5
6
7 package mongo
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "strings"
14
15 "go.mongodb.org/mongo-driver/bson"
16 "go.mongodb.org/mongo-driver/bson/bsonrw"
17 "go.mongodb.org/mongo-driver/bson/primitive"
18 "go.mongodb.org/mongo-driver/mongo/options"
19 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
20 "go.mongodb.org/mongo-driver/x/mongo/driver"
21 "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
22 mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
23 )
24
25
// ClientEncryption manages data keys stored in a key vault collection and performs explicit
// encryption and decryption of BSON values. Create one with NewClientEncryption.
type ClientEncryption struct {
	// crypt performs data key creation, explicit encryption/decryption, and data key rewrapping.
	crypt driver.Crypt
	// keyVaultClient is the client supplied to NewClientEncryption; Close disconnects it.
	keyVaultClient *Client
	// keyVaultColl is the key vault collection that data key documents are inserted into,
	// read from, updated in, and deleted from.
	keyVaultColl *Collection
}
31
32
33 func NewClientEncryption(keyVaultClient *Client, opts ...*options.ClientEncryptionOptions) (*ClientEncryption, error) {
34 if keyVaultClient == nil {
35 return nil, errors.New("keyVaultClient must not be nil")
36 }
37
38 ce := &ClientEncryption{
39 keyVaultClient: keyVaultClient,
40 }
41 ceo := options.MergeClientEncryptionOptions(opts...)
42
43
44 db, coll := splitNamespace(ceo.KeyVaultNamespace)
45 ce.keyVaultColl = ce.keyVaultClient.Database(db).Collection(coll, keyVaultCollOpts)
46
47 kmsProviders, err := marshal(ceo.KmsProviders, nil, nil)
48 if err != nil {
49 return nil, fmt.Errorf("error creating KMS providers map: %w", err)
50 }
51
52 mc, err := mongocrypt.NewMongoCrypt(mcopts.MongoCrypt().
53 SetKmsProviders(kmsProviders).
54
55
56
57 SetCryptSharedLibDisabled(true).
58 SetHTTPClient(ceo.HTTPClient))
59 if err != nil {
60 return nil, err
61 }
62
63
64 kr := keyRetriever{coll: ce.keyVaultColl}
65 cir := collInfoRetriever{client: ce.keyVaultClient}
66 ce.crypt = driver.NewCrypt(&driver.CryptOptions{
67 MongoCrypt: mc,
68 KeyFn: kr.cryptKeys,
69 CollInfoFn: cir.cryptCollInfo,
70 TLSConfig: ceo.TLSConfig,
71 })
72
73 return ce, nil
74 }
75
76
77
78 func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context,
79 db *Database, coll string, createOpts *options.CreateCollectionOptions,
80 kmsProvider string, masterKey interface{}) (*Collection, bson.M, error) {
81 if createOpts == nil {
82 return nil, nil, errors.New("nil CreateCollectionOptions")
83 }
84 ef := createOpts.EncryptedFields
85 if ef == nil {
86 return nil, nil, errors.New("no EncryptedFields defined for the collection")
87 }
88
89 efBSON, err := marshal(ef, db.bsonOpts, db.registry)
90 if err != nil {
91 return nil, nil, err
92 }
93 r := bsonrw.NewBSONDocumentReader(efBSON)
94 dec, err := bson.NewDecoder(r)
95 if err != nil {
96 return nil, nil, err
97 }
98 var m bson.M
99 err = dec.Decode(&m)
100 if err != nil {
101 return nil, nil, err
102 }
103
104 if v, ok := m["fields"]; ok {
105 if fields, ok := v.(bson.A); ok {
106 for _, field := range fields {
107 if f, ok := field.(bson.M); !ok {
108 continue
109 } else if v, ok := f["keyId"]; ok && v == nil {
110 dkOpts := options.DataKey()
111 if masterKey != nil {
112 dkOpts.SetMasterKey(masterKey)
113 }
114 keyid, err := ce.CreateDataKey(ctx, kmsProvider, dkOpts)
115 if err != nil {
116 createOpts.EncryptedFields = m
117 return nil, m, err
118 }
119 f["keyId"] = keyid
120 }
121 }
122 createOpts.EncryptedFields = m
123 }
124 }
125 err = db.CreateCollection(ctx, coll, createOpts)
126 if err != nil {
127 return nil, m, err
128 }
129 return db.Collection(coll), m, nil
130 }
131
132
133
134 func (ce *ClientEncryption) AddKeyAltName(ctx context.Context, id primitive.Binary, keyAltName string) *SingleResult {
135 filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
136 keyAltNameDoc := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build()
137 update := bsoncore.NewDocumentBuilder().AppendDocument("$addToSet", keyAltNameDoc).Build()
138 return ce.keyVaultColl.FindOneAndUpdate(ctx, filter, update)
139 }
140
141
142
143 func (ce *ClientEncryption) CreateDataKey(ctx context.Context, kmsProvider string,
144 opts ...*options.DataKeyOptions) (primitive.Binary, error) {
145
146
147 dko := options.MergeDataKeyOptions(opts...)
148 co := mcopts.DataKey().SetKeyAltNames(dko.KeyAltNames)
149 if dko.MasterKey != nil {
150 keyDoc, err := marshal(
151 dko.MasterKey,
152 ce.keyVaultClient.bsonOpts,
153 ce.keyVaultClient.registry)
154 if err != nil {
155 return primitive.Binary{}, err
156 }
157 co.SetMasterKey(keyDoc)
158 }
159 if dko.KeyMaterial != nil {
160 co.SetKeyMaterial(dko.KeyMaterial)
161 }
162
163
164 dataKeyDoc, err := ce.crypt.CreateDataKey(ctx, kmsProvider, co)
165 if err != nil {
166 return primitive.Binary{}, err
167 }
168
169
170 _, err = ce.keyVaultColl.InsertOne(ctx, dataKeyDoc)
171 if err != nil {
172 return primitive.Binary{}, err
173 }
174
175 subtype, data := bson.Raw(dataKeyDoc).Lookup("_id").Binary()
176 return primitive.Binary{Subtype: subtype, Data: data}, nil
177 }
178
179
180 func transformExplicitEncryptionOptions(opts ...*options.EncryptOptions) *mcopts.ExplicitEncryptionOptions {
181 eo := options.MergeEncryptOptions(opts...)
182 transformed := mcopts.ExplicitEncryption()
183 if eo.KeyID != nil {
184 transformed.SetKeyID(*eo.KeyID)
185 }
186 if eo.KeyAltName != nil {
187 transformed.SetKeyAltName(*eo.KeyAltName)
188 }
189 transformed.SetAlgorithm(eo.Algorithm)
190 transformed.SetQueryType(eo.QueryType)
191
192 if eo.ContentionFactor != nil {
193 transformed.SetContentionFactor(*eo.ContentionFactor)
194 }
195
196 if eo.RangeOptions != nil {
197 var transformedRange mcopts.ExplicitRangeOptions
198 if eo.RangeOptions.Min != nil {
199 transformedRange.Min = &bsoncore.Value{Type: eo.RangeOptions.Min.Type, Data: eo.RangeOptions.Min.Value}
200 }
201 if eo.RangeOptions.Max != nil {
202 transformedRange.Max = &bsoncore.Value{Type: eo.RangeOptions.Max.Type, Data: eo.RangeOptions.Max.Value}
203 }
204 if eo.RangeOptions.Precision != nil {
205 transformedRange.Precision = eo.RangeOptions.Precision
206 }
207 transformedRange.Sparsity = eo.RangeOptions.Sparsity
208 transformed.SetRangeOptions(transformedRange)
209 }
210 return transformed
211 }
212
213
214 func (ce *ClientEncryption) Encrypt(ctx context.Context, val bson.RawValue,
215 opts ...*options.EncryptOptions) (primitive.Binary, error) {
216
217 transformed := transformExplicitEncryptionOptions(opts...)
218 subtype, data, err := ce.crypt.EncryptExplicit(ctx, bsoncore.Value{Type: val.Type, Data: val.Value}, transformed)
219 if err != nil {
220 return primitive.Binary{}, err
221 }
222 return primitive.Binary{Subtype: subtype, Data: data}, nil
223 }
224
225
226
227
228
229
230
231
232
233
234
235 func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interface{}, result interface{}, opts ...*options.EncryptOptions) error {
236 transformed := transformExplicitEncryptionOptions(opts...)
237
238 exprDoc, err := marshal(expr, nil, nil)
239 if err != nil {
240 return err
241 }
242
243 encryptedExprDoc, err := ce.crypt.EncryptExplicitExpression(ctx, exprDoc, transformed)
244 if err != nil {
245 return err
246 }
247 if raw, ok := result.(*bson.Raw); ok {
248
249 *raw = bson.Raw(encryptedExprDoc)
250 return nil
251 }
252 err = bson.Unmarshal([]byte(encryptedExprDoc), result)
253 if err != nil {
254 return err
255 }
256 return nil
257 }
258
259
260 func (ce *ClientEncryption) Decrypt(ctx context.Context, val primitive.Binary) (bson.RawValue, error) {
261 decrypted, err := ce.crypt.DecryptExplicit(ctx, val.Subtype, val.Data)
262 if err != nil {
263 return bson.RawValue{}, err
264 }
265
266 return bson.RawValue{Type: decrypted.Type, Value: decrypted.Data}, nil
267 }
268
269
270
// Close closes the underlying crypt and then disconnects the key vault client that was passed to
// NewClientEncryption, returning any error from the disconnect.
func (ce *ClientEncryption) Close(ctx context.Context) error {
	ce.crypt.Close()
	return ce.keyVaultClient.Disconnect(ctx)
}
275
276
277
278 func (ce *ClientEncryption) DeleteKey(ctx context.Context, id primitive.Binary) (*DeleteResult, error) {
279 filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
280 return ce.keyVaultColl.DeleteOne(ctx, filter)
281 }
282
283
284 func (ce *ClientEncryption) GetKeyByAltName(ctx context.Context, keyAltName string) *SingleResult {
285 filter := bsoncore.NewDocumentBuilder().AppendString("keyAltNames", keyAltName).Build()
286 return ce.keyVaultColl.FindOne(ctx, filter)
287 }
288
289
290
291 func (ce *ClientEncryption) GetKey(ctx context.Context, id primitive.Binary) *SingleResult {
292 filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
293 return ce.keyVaultColl.FindOne(ctx, filter)
294 }
295
296
297
298 func (ce *ClientEncryption) GetKeys(ctx context.Context) (*Cursor, error) {
299 return ce.keyVaultColl.Find(ctx, bson.D{})
300 }
301
302
303
304 func (ce *ClientEncryption) RemoveKeyAltName(ctx context.Context, id primitive.Binary, keyAltName string) *SingleResult {
305 filter := bsoncore.NewDocumentBuilder().AppendBinary("_id", id.Subtype, id.Data).Build()
306 update := bson.A{bson.D{{"$set", bson.D{{"keyAltNames", bson.D{{"$cond", bson.A{bson.D{{"$eq",
307 bson.A{"$keyAltNames", bson.A{keyAltName}}}}, "$$REMOVE", bson.D{{"$filter",
308 bson.D{{"input", "$keyAltNames"}, {"cond", bson.D{{"$ne", bson.A{"$$this", keyAltName}}}}}}}}}}}}}}}
309 return ce.keyVaultColl.FindOneAndUpdate(ctx, filter, update)
310 }
311
312
313 func setRewrapManyDataKeyWriteModels(rewrappedDocuments []bsoncore.Document, writeModels *[]WriteModel) error {
314 const idKey = "_id"
315 const keyMaterial = "keyMaterial"
316 const masterKey = "masterKey"
317
318 if writeModels == nil {
319 return fmt.Errorf("writeModels pointer not set for location referenced")
320 }
321
322
323 for _, rewrappedDocument := range rewrappedDocuments {
324
325 masterKeyValue, err := rewrappedDocument.LookupErr(masterKey)
326 if err != nil {
327 return err
328 }
329 masterKeyDoc := masterKeyValue.Document()
330
331
332 keyMaterialValue, err := rewrappedDocument.LookupErr(keyMaterial)
333 if err != nil {
334 return err
335 }
336 keyMaterialSubtype, keyMaterialData := keyMaterialValue.Binary()
337 keyMaterialBinary := primitive.Binary{Subtype: keyMaterialSubtype, Data: keyMaterialData}
338
339
340 id, err := rewrappedDocument.LookupErr(idKey)
341 if err != nil {
342 return err
343 }
344
345 idSubtype, idData, ok := id.BinaryOK()
346 if !ok {
347 return fmt.Errorf("expected to assert %q as binary, got type %T", idKey, id)
348 }
349 binaryID := primitive.Binary{Subtype: idSubtype, Data: idData}
350
351
352 *writeModels = append(*writeModels, NewUpdateOneModel().
353 SetFilter(bson.D{{idKey, binaryID}}).
354 SetUpdate(
355 bson.D{
356 {"$set", bson.D{{keyMaterial, keyMaterialBinary}, {masterKey, masterKeyDoc}}},
357 {"$currentDate", bson.D{{"updateDate", true}}},
358 },
359 ))
360 }
361 return nil
362 }
363
364
365
366
367
368 func (ce *ClientEncryption) RewrapManyDataKey(ctx context.Context, filter interface{},
369 opts ...*options.RewrapManyDataKeyOptions) (*RewrapManyDataKeyResult, error) {
370
371
372
373 libmongocryptVersion := mongocrypt.Version()
374 if strings.HasPrefix(libmongocryptVersion, "1.5.0") || strings.HasPrefix(libmongocryptVersion, "1.5.1") {
375 return nil, fmt.Errorf("RewrapManyDataKey requires libmongocrypt 1.5.2 or newer. Detected version: %v", libmongocryptVersion)
376 }
377
378 rmdko := options.MergeRewrapManyDataKeyOptions(opts...)
379 if ctx == nil {
380 ctx = context.Background()
381 }
382
383
384 co := mcopts.RewrapManyDataKey()
385 if rmdko.MasterKey != nil {
386 keyDoc, err := marshal(
387 rmdko.MasterKey,
388 ce.keyVaultClient.bsonOpts,
389 ce.keyVaultClient.registry)
390 if err != nil {
391 return nil, err
392 }
393 co.SetMasterKey(keyDoc)
394 }
395 if rmdko.Provider != nil {
396 co.SetProvider(*rmdko.Provider)
397 }
398
399
400 filterdoc, err := marshal(filter, ce.keyVaultClient.bsonOpts, ce.keyVaultClient.registry)
401 if err != nil {
402 return nil, err
403 }
404
405 rewrappedDocuments, err := ce.crypt.RewrapDataKey(ctx, filterdoc, co)
406 if err != nil {
407 return nil, err
408 }
409 if len(rewrappedDocuments) == 0 {
410
411 return new(RewrapManyDataKeyResult), nil
412 }
413
414
415 models := []WriteModel{}
416 if err := setRewrapManyDataKeyWriteModels(rewrappedDocuments, &models); err != nil {
417 return nil, err
418 }
419
420 bulkWriteResults, err := ce.keyVaultColl.BulkWrite(ctx, models)
421 return &RewrapManyDataKeyResult{BulkWriteResult: bulkWriteResults}, err
422 }
423
424
// splitNamespace splits a "db.coll" namespace at its first '.' into database and collection
// names. Any later dots stay in the collection name. A namespace with no dot is treated as a
// bare collection name with an empty database name.
func splitNamespace(ns string) (string, string) {
	if dot := strings.IndexByte(ns, '.'); dot >= 0 {
		return ns[:dot], ns[dot+1:]
	}
	return "", ns
}
433
View as plain text