...

Source file src/go.mongodb.org/mongo-driver/cmd/testkms/main.go

Documentation: go.mongodb.org/mongo-driver/cmd/testkms

     1  // Copyright (C) MongoDB, Inc. 2022-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package main
     8  
     9  import (
    10  	"context"
    11  	"fmt"
    12  	"os"
    13  	"strings"
    14  
    15  	"go.mongodb.org/mongo-driver/bson"
    16  	"go.mongodb.org/mongo-driver/bson/primitive"
    17  	"go.mongodb.org/mongo-driver/mongo"
    18  	"go.mongodb.org/mongo-driver/mongo/options"
    19  )
    20  
    21  var datakeyopts = map[string]primitive.M{
    22  	"aws": bson.M{
    23  		"region": "us-east-1",
    24  		"key":    "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
    25  	},
    26  	"azure": bson.M{
    27  		"keyVaultEndpoint": "",
    28  		"keyName":          "",
    29  	},
    30  	"gcp": bson.M{
    31  		"projectId": "devprod-drivers",
    32  		"location":  "global",
    33  		"keyRing":   "key-ring-csfle",
    34  		"keyName":   "key-name-csfle",
    35  	},
    36  }
    37  
    38  func main() {
    39  	uri := os.Getenv("MONGODB_URI")
    40  	provider := os.Getenv("PROVIDER")
    41  	// expecterror is an expect error substring. Set to empty string to expect no error.
    42  	expecterror := os.Getenv("EXPECT_ERROR")
    43  
    44  	datakeyopt, validKmsProvider := datakeyopts[provider]
    45  	ok := false
    46  	switch {
    47  	case uri == "":
    48  		fmt.Println("ERROR: Please set required MONGODB_URI environment variable.")
    49  	case provider == "":
    50  		fmt.Println("ERROR: Please set required PROVIDER environment variable.")
    51  	case !validKmsProvider:
    52  		fmt.Println("ERROR: Unsupported PROVIDER value.")
    53  	default:
    54  		ok = true
    55  	}
    56  	if provider == "azure" {
    57  		azureKmsKeyName := os.Getenv("AZUREKMS_KEY_NAME")
    58  		azureKmsKeyVaultEndpoint := os.Getenv("AZUREKMS_KEY_VAULT_ENDPOINT")
    59  		if azureKmsKeyName == "" {
    60  			fmt.Println("ERROR: Please set required AZUREKMS_KEY_NAME environment variable.")
    61  			ok = false
    62  		}
    63  		if azureKmsKeyVaultEndpoint == "" {
    64  			fmt.Println("ERROR: Please set required AZUREKMS_KEY_VAULT_ENDPOINT environment variable.")
    65  			ok = false
    66  		}
    67  		datakeyopts["azure"]["keyName"] = azureKmsKeyName
    68  		datakeyopts["azure"]["keyVaultEndpoint"] = azureKmsKeyVaultEndpoint
    69  	}
    70  	if !ok {
    71  		providers := make([]string, 0, len(datakeyopts))
    72  		for p := range datakeyopts {
    73  			providers = append(providers, p)
    74  		}
    75  
    76  		fmt.Println("The following environment variables are understood:")
    77  		fmt.Println("- MONGODB_URI as a MongoDB URI. Example: 'mongodb://localhost:27017'")
    78  		fmt.Println("- EXPECT_ERROR as an optional expected error substring.")
    79  		fmt.Println("- PROVIDER as a KMS provider, which supports:", strings.Join(providers, ", "))
    80  		fmt.Println("- AZUREKMS_KEY_NAME as the Azure key name. Required if PROVIDER=azure.")
    81  		fmt.Println("- AZUREKMS_KEY_VAULT_ENDPOINT as the Azure key name. Required if PROVIDER=azure.")
    82  		os.Exit(1)
    83  	}
    84  
    85  	cOpts := options.Client().ApplyURI(uri)
    86  	keyVaultClient, err := mongo.Connect(context.Background(), cOpts)
    87  	if err != nil {
    88  		panic(fmt.Sprintf("Connect error: %v", err))
    89  	}
    90  	defer func() { _ = keyVaultClient.Disconnect(context.Background()) }()
    91  
    92  	kmsProvidersMap := map[string]map[string]interface{}{
    93  		provider: {},
    94  	}
    95  	ceOpts := options.ClientEncryption().SetKmsProviders(kmsProvidersMap).SetKeyVaultNamespace("keyvault.datakeys")
    96  	ce, err := mongo.NewClientEncryption(keyVaultClient, ceOpts)
    97  	if err != nil {
    98  		panic(fmt.Sprintf("Error in NewClientEncryption: %v", err))
    99  	}
   100  	dkOpts := options.DataKey().SetMasterKey(datakeyopt)
   101  	_, err = ce.CreateDataKey(context.Background(), provider, dkOpts)
   102  	if expecterror == "" {
   103  		if err != nil {
   104  			panic(fmt.Sprintf("Expected success, but got error in CreateDataKey: %v", err))
   105  		}
   106  	} else {
   107  		if err == nil {
   108  			panic(fmt.Sprintf("Expected error message to contain %q, but got no error", expecterror))
   109  		}
   110  		if !strings.Contains(err.Error(), expecterror) {
   111  			panic(fmt.Sprintf("Expected error message to contain %q, but got %q", expecterror, err.Error()))
   112  		}
   113  	}
   114  }
   115  

View as plain text