1
2
3
4
5
6
7 package driver
8
9 import (
10 "context"
11 "crypto/tls"
12 "errors"
13 "fmt"
14 "io"
15 "strings"
16 "time"
17
18 "go.mongodb.org/mongo-driver/bson/bsontype"
19 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
20 "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
21 "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
22 )
23
// Defaults used by decryptKey when talking to a KMS server:
//   - defaultKmsPort is appended to a KMS host name that does not already contain a ':'.
//   - defaultKmsTimeout is the write deadline applied before the KMS request is sent.
const (
	defaultKmsPort    = 443
	defaultKmsTimeout = 10 * time.Second
)
28
29
// CollectionInfoFn is a callback that takes a database name and a filter document and returns a
// single document. The crypt type invokes it while the state machine is in the NeedMongoCollInfo
// state; a nil document is accepted and simply not fed back to the context.
type CollectionInfoFn func(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error)
31
32
// KeyRetrieverFn is a callback that takes a filter document and returns zero or more documents.
// The crypt type invokes it while the state machine is in the NeedMongoKeys state and feeds every
// returned document back to the context.
type KeyRetrieverFn func(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error)
34
35
// MarkCommandFn is a callback that takes a database name and a command document and returns a
// document. The crypt type invokes it while the state machine is in the NeedMongoMarkings state
// and feeds the returned document back to the context.
type MarkCommandFn func(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
37
38
// CryptOptions holds the configuration consumed by NewCrypt.
//
// MongoCrypt, CollInfoFn, KeyFn, MarkFn, TLSConfig and BypassAutoEncryption are copied into the
// value NewCrypt returns. TLSConfig is keyed by KMS provider name. BypassQueryAnalysis is not read
// by NewCrypt in this file.
type CryptOptions struct {
	MongoCrypt           *mongocrypt.MongoCrypt
	CollInfoFn           CollectionInfoFn
	KeyFn                KeyRetrieverFn
	MarkFn               MarkCommandFn
	TLSConfig            map[string]*tls.Config
	BypassAutoEncryption bool
	BypassQueryAnalysis  bool
}
48
49
50
51
52
53
// Crypt is the set of encryption and decryption operations over BSON documents and values.
// NewCrypt returns the implementation defined in this file.
type Crypt interface {
	// Encrypt encrypts the command cmd, which targets the database db, and returns the resulting document.
	Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
	// Decrypt decrypts the command response cmdResponse and returns the resulting document.
	Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error)
	// CreateDataKey creates a data key for the named KMS provider using opts and returns a document.
	CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error)
	// EncryptExplicit encrypts val using opts. The implementation in this file returns the subtype
	// and data of the BSON binary value produced by the encryption.
	EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error)
	// EncryptExplicitExpression encrypts the expression document val using opts and returns a document.
	EncryptExplicitExpression(ctx context.Context, val bsoncore.Document, opts *options.ExplicitEncryptionOptions) (bsoncore.Document, error)
	// DecryptExplicit decrypts the BSON binary value described by subtype and data and returns the resulting value.
	DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error)
	// Close cleans up the Crypt. The implementation in this file closes its underlying MongoCrypt.
	Close()
	// BypassAutoEncryption reports whether automatic encryption is bypassed.
	BypassAutoEncryption() bool

	// RewrapDataKey processes the data keys selected by filter using opts and returns the resulting
	// documents. NOTE(review): presumably these are the rewrapped key documents; confirm against callers.
	RewrapDataKey(ctx context.Context, filter []byte, opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error)
}
75
76
77
// crypt is the Crypt implementation returned by NewCrypt. It creates mongocrypt contexts from
// mongoCrypt and drives them through their state machine, using collInfoFn, markFn and keyFn to
// satisfy the requests a context makes and tlsConfig (keyed by KMS provider name) when connecting
// to a KMS server. When bypassAutoEncryption is set, Encrypt returns commands unchanged.
type crypt struct {
	mongoCrypt *mongocrypt.MongoCrypt
	collInfoFn CollectionInfoFn
	keyFn      KeyRetrieverFn
	markFn     MarkCommandFn
	tlsConfig  map[string]*tls.Config

	bypassAutoEncryption bool
}
87
88
89 func NewCrypt(opts *CryptOptions) Crypt {
90 c := &crypt{
91 mongoCrypt: opts.MongoCrypt,
92 collInfoFn: opts.CollInfoFn,
93 keyFn: opts.KeyFn,
94 markFn: opts.MarkFn,
95 tlsConfig: opts.TLSConfig,
96 bypassAutoEncryption: opts.BypassAutoEncryption,
97 }
98 return c
99 }
100
101
102 func (c *crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error) {
103 if c.bypassAutoEncryption {
104 return cmd, nil
105 }
106
107 cryptCtx, err := c.mongoCrypt.CreateEncryptionContext(db, cmd)
108 if err != nil {
109 return nil, err
110 }
111 defer cryptCtx.Close()
112
113 return c.executeStateMachine(ctx, cryptCtx, db)
114 }
115
116
117 func (c *crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
118 cryptCtx, err := c.mongoCrypt.CreateDecryptionContext(cmdResponse)
119 if err != nil {
120 return nil, err
121 }
122 defer cryptCtx.Close()
123
124 return c.executeStateMachine(ctx, cryptCtx, "")
125 }
126
127
128 func (c *crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error) {
129 cryptCtx, err := c.mongoCrypt.CreateDataKeyContext(kmsProvider, opts)
130 if err != nil {
131 return nil, err
132 }
133 defer cryptCtx.Close()
134
135 return c.executeStateMachine(ctx, cryptCtx, "")
136 }
137
138
139
140 func (c *crypt) RewrapDataKey(ctx context.Context, filter []byte,
141 opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error) {
142
143 cryptCtx, err := c.mongoCrypt.RewrapDataKeyContext(filter, opts)
144 if err != nil {
145 return nil, err
146 }
147 defer cryptCtx.Close()
148
149 rewrappedBSON, err := c.executeStateMachine(ctx, cryptCtx, "")
150 if err != nil {
151 return nil, err
152 }
153 if rewrappedBSON == nil {
154 return nil, nil
155 }
156
157
158
159 rewrappedDocumentBytes, err := rewrappedBSON.LookupErr("v")
160 if err != nil {
161 return nil, err
162 }
163
164
165 rewrappedDocsArray, ok := rewrappedDocumentBytes.ArrayOK()
166 if !ok {
167 return nil, fmt.Errorf("expected results from mongocrypt_ctx_rewrap_many_datakey_init to be an array")
168 }
169
170 rewrappedDocumentValues, err := rewrappedDocsArray.Values()
171 if err != nil {
172 return nil, err
173 }
174
175 rewrappedDocuments := []bsoncore.Document{}
176 for _, rewrappedDocumentValue := range rewrappedDocumentValues {
177 if rewrappedDocumentValue.Type != bsontype.EmbeddedDocument {
178
179
180 return nil, fmt.Errorf("expected value of type %q, got: %q",
181 bsontype.EmbeddedDocument.String(),
182 rewrappedDocumentValue.Type.String())
183 }
184 rewrappedDocuments = append(rewrappedDocuments, rewrappedDocumentValue.Document())
185 }
186 return rewrappedDocuments, nil
187 }
188
189
190 func (c *crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error) {
191 idx, doc := bsoncore.AppendDocumentStart(nil)
192 doc = bsoncore.AppendValueElement(doc, "v", val)
193 doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
194
195 cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionContext(doc, opts)
196 if err != nil {
197 return 0, nil, err
198 }
199 defer cryptCtx.Close()
200
201 res, err := c.executeStateMachine(ctx, cryptCtx, "")
202 if err != nil {
203 return 0, nil, err
204 }
205
206 sub, data := res.Lookup("v").Binary()
207 return sub, data, nil
208 }
209
210
211 func (c *crypt) EncryptExplicitExpression(ctx context.Context, expr bsoncore.Document, opts *options.ExplicitEncryptionOptions) (bsoncore.Document, error) {
212 idx, doc := bsoncore.AppendDocumentStart(nil)
213 doc = bsoncore.AppendDocumentElement(doc, "v", expr)
214 doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
215
216 cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionExpressionContext(doc, opts)
217 if err != nil {
218 return nil, err
219 }
220 defer cryptCtx.Close()
221
222 res, err := c.executeStateMachine(ctx, cryptCtx, "")
223 if err != nil {
224 return nil, err
225 }
226
227 encryptedExpr := res.Lookup("v").Document()
228 return encryptedExpr, nil
229 }
230
231
232 func (c *crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error) {
233 idx, doc := bsoncore.AppendDocumentStart(nil)
234 doc = bsoncore.AppendBinaryElement(doc, "v", subtype, data)
235 doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
236
237 cryptCtx, err := c.mongoCrypt.CreateExplicitDecryptionContext(doc)
238 if err != nil {
239 return bsoncore.Value{}, err
240 }
241 defer cryptCtx.Close()
242
243 res, err := c.executeStateMachine(ctx, cryptCtx, "")
244 if err != nil {
245 return bsoncore.Value{}, err
246 }
247
248 return res.Lookup("v"), nil
249 }
250
251
// Close closes the underlying MongoCrypt instance held by c.
func (c *crypt) Close() {
	c.mongoCrypt.Close()
}
255
// BypassAutoEncryption reports the bypass flag c was constructed with. When it is true, Encrypt
// returns commands unchanged.
func (c *crypt) BypassAutoEncryption() bool {
	return c.bypassAutoEncryption
}
259
260 func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Context, db string) (bsoncore.Document, error) {
261 var err error
262 for {
263 state := cryptCtx.State()
264 switch state {
265 case mongocrypt.NeedMongoCollInfo:
266 err = c.collectionInfo(ctx, cryptCtx, db)
267 case mongocrypt.NeedMongoMarkings:
268 err = c.markCommand(ctx, cryptCtx, db)
269 case mongocrypt.NeedMongoKeys:
270 err = c.retrieveKeys(ctx, cryptCtx)
271 case mongocrypt.NeedKms:
272 err = c.decryptKeys(cryptCtx)
273 case mongocrypt.Ready:
274 return cryptCtx.Finish()
275 case mongocrypt.Done:
276 return nil, nil
277 case mongocrypt.NeedKmsCredentials:
278 err = c.provideKmsProviders(ctx, cryptCtx)
279 default:
280 return nil, fmt.Errorf("invalid Crypt state: %v", state)
281 }
282 if err != nil {
283 return nil, err
284 }
285 }
286 }
287
288 func (c *crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
289 op, err := cryptCtx.NextOperation()
290 if err != nil {
291 return err
292 }
293
294 collInfo, err := c.collInfoFn(ctx, db, op)
295 if err != nil {
296 return err
297 }
298 if collInfo != nil {
299 if err = cryptCtx.AddOperationResult(collInfo); err != nil {
300 return err
301 }
302 }
303
304 return cryptCtx.CompleteOperation()
305 }
306
307 func (c *crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
308 op, err := cryptCtx.NextOperation()
309 if err != nil {
310 return err
311 }
312
313 markedCmd, err := c.markFn(ctx, db, op)
314 if err != nil {
315 return err
316 }
317 if err = cryptCtx.AddOperationResult(markedCmd); err != nil {
318 return err
319 }
320
321 return cryptCtx.CompleteOperation()
322 }
323
324 func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
325 op, err := cryptCtx.NextOperation()
326 if err != nil {
327 return err
328 }
329
330 keys, err := c.keyFn(ctx, op)
331 if err != nil {
332 return err
333 }
334
335 for _, key := range keys {
336 if err = cryptCtx.AddOperationResult(key); err != nil {
337 return err
338 }
339 }
340
341 return cryptCtx.CompleteOperation()
342 }
343
344 func (c *crypt) decryptKeys(cryptCtx *mongocrypt.Context) error {
345 for {
346 kmsCtx := cryptCtx.NextKmsContext()
347 if kmsCtx == nil {
348 break
349 }
350
351 if err := c.decryptKey(kmsCtx); err != nil {
352 return err
353 }
354 }
355
356 return cryptCtx.FinishKmsContexts()
357 }
358
359 func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
360 host, err := kmsCtx.HostName()
361 if err != nil {
362 return err
363 }
364 msg, err := kmsCtx.Message()
365 if err != nil {
366 return err
367 }
368
369
370 addr := host
371 if idx := strings.IndexByte(host, ':'); idx == -1 {
372 addr = fmt.Sprintf("%s:%d", host, defaultKmsPort)
373 }
374
375 kmsProvider := kmsCtx.KMSProvider()
376 tlsCfg := c.tlsConfig[kmsProvider]
377 if tlsCfg == nil {
378 tlsCfg = &tls.Config{MinVersion: tls.VersionTLS12}
379 }
380 conn, err := tls.Dial("tcp", addr, tlsCfg)
381 if err != nil {
382 return err
383 }
384 defer func() {
385 _ = conn.Close()
386 }()
387
388 if err = conn.SetWriteDeadline(time.Now().Add(defaultKmsTimeout)); err != nil {
389 return err
390 }
391 if _, err = conn.Write(msg); err != nil {
392 return err
393 }
394
395 for {
396 bytesNeeded := kmsCtx.BytesNeeded()
397 if bytesNeeded == 0 {
398 return nil
399 }
400
401 res := make([]byte, bytesNeeded)
402 bytesRead, err := conn.Read(res)
403 if err != nil && !errors.Is(err, io.EOF) {
404 return err
405 }
406
407 if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil {
408 return err
409 }
410 }
411 }
412
413 func (c *crypt) provideKmsProviders(ctx context.Context, cryptCtx *mongocrypt.Context) error {
414 kmsProviders, err := c.mongoCrypt.GetKmsProviders(ctx)
415 if err != nil {
416 return err
417 }
418 return cryptCtx.ProvideKmsProviders(kmsProviders)
419 }
420
View as plain text