1 package crypto
2
3 import (
4 "crypto/aes"
5 "crypto/cipher"
6 "crypto/rand"
7 "encoding/base64"
8 "encoding/json"
9 "io"
10 "strings"
11
12 "github.com/pkg/errors"
13 )
14
15
// aeadKey copies the first 32 bytes of key into a freshly allocated
// fixed-size array and returns a pointer to it. The returned array does not
// alias key, so later changes to key do not affect it.
//
// It panics if len(key) < 32; the callers in this file reject any key whose
// length is not exactly 32 before calling it.
func aeadKey(key []byte) *[32]byte {
	out := new([32]byte)
	copy(out[:], key[:32])
	return out
}
21
22
23
24
25 func Encrypt(data []byte, key []byte) (string, error) {
26 ciphertext, err := encrypt(data, key)
27 if err != nil {
28 return "", errors.WithStack(err)
29 }
30
31 return base64.URLEncoding.EncodeToString(ciphertext), nil
32 }
33
34
// Decrypt is the inverse of Encrypt. It decodes data from padded, URL-safe
// base64 and then opens the resulting AES-GCM message with key.
// Decoding and decryption errors are both returned with a stack trace attached.
func Decrypt(data string, key []byte) ([]byte, error) {
	sealed, decodeErr := base64.URLEncoding.DecodeString(data)
	if decodeErr != nil {
		return nil, errors.WithStack(decodeErr)
	}

	opened, openErr := decrypt(sealed, key)
	if openErr != nil {
		return nil, errors.WithStack(openErr)
	}
	return opened, nil
}
48
49
50
51 func EncryptJSON(value json.RawMessage, key []byte) (json.RawMessage, error) {
52 encrypted, err := Encrypt(value, key)
53 if err != nil {
54 return nil, err
55 }
56
57 encryptedJSON := `{ "EncryptedData": "` + encrypted + `" }`
58 jsonMessage := json.RawMessage{}
59 err = json.Unmarshal([]byte(encryptedJSON), &jsonMessage)
60 if err != nil {
61 return nil, err
62 }
63
64 return jsonMessage, nil
65 }
66
67
68 func DecryptJSON(value json.RawMessage, key []byte) (json.RawMessage, error) {
69 msg := &map[string]string{}
70 err := json.Unmarshal(value, msg)
71 if err != nil {
72 return nil, err
73 }
74
75 encryptedData := (*msg)["EncryptedData"]
76
77 decryptedData, err := Decrypt(encryptedData, key)
78 if err != nil {
79 return nil, err
80 }
81
82 return decryptedData, nil
83 }
84
85
86 func EncryptRedis(value []byte, key []byte) (string, error) {
87 data, err := Encrypt(value, key)
88 if err != nil {
89 return "", err
90 }
91
92 return "EncryptedData:" + data, nil
93 }
94
95
96 func DecryptRedis(value string, key []byte) ([]byte, error) {
97 prefix := "EncryptedData:"
98
99 if strings.HasPrefix(value, prefix) {
100
101 data := []byte(value[len(prefix):])
102
103
104 decryptedData, err := Decrypt(string(data), key)
105 if err != nil {
106 return nil, err
107 }
108 return decryptedData, nil
109 }
110 return nil, nil
111 }
112
113
// encrypt seals plaintext with AES-GCM under a 32-byte key (AES-256).
//
// The output is nonce || sealed data, where the sealed data includes the GCM
// authentication tag. A fresh nonce is read from crypto/rand on every call.
// key must be exactly 32 bytes.
//
// Error messages deliberately never include key material.
func encrypt(plaintext []byte, key []byte) ([]byte, error) {
	if len(key) != 32 {
		return nil, errors.Errorf("key must be exactly 32 bytes long, got %d bytes", len(key))
	}

	encryptionKey := aeadKey(key)

	block, err := aes.NewCipher(encryptionKey[:])
	if err != nil {
		// Never format the key into the message: errors end up in logs, and
		// the key protects every ciphertext produced with it.
		return nil, errors.Wrap(err, "unable to create AES cipher block")
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, errors.Wrap(err, "unable to create GCM")
	}

	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, errors.Wrap(err, "unable to read random nonce")
	}

	// Seal appends to its dst argument (here nonce), so the result is the
	// nonce followed by the sealed data. decrypt splits the two apart at
	// gcm.NonceSize().
	return gcm.Seal(nonce, nonce, plaintext, nil), nil
}
146
147
// decrypt opens a message produced by encrypt. The first gcm.NonceSize()
// bytes of ciphertext are the nonce; the rest is the GCM-sealed payload,
// including its authentication tag. key must be exactly 32 bytes.
//
// Input too short to hold a nonce and a tag is rejected as malformed.
// Authentication failures are reported by gcm.Open's error.
//
// Error messages deliberately never include key material.
func decrypt(ciphertext []byte, key []byte) ([]byte, error) {
	if len(key) != 32 {
		return nil, errors.Errorf("key must be exactly 32 bytes long, got %d bytes", len(key))
	}

	encryptionKey := aeadKey(key)

	block, err := aes.NewCipher(encryptionKey[:])
	if err != nil {
		// Never format the key into the message: errors end up in logs.
		return nil, errors.Wrap(err, "unable to create AES cipher block")
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, errors.Wrap(err, "unable to create GCM")
	}

	nonceSize := gcm.NonceSize()
	if len(ciphertext) < nonceSize+gcm.Overhead() {
		return nil, errors.New("malformed ciphertext")
	}

	nonce, sealed := ciphertext[:nonceSize], ciphertext[nonceSize:]
	return gcm.Open(nil, nonce, sealed, nil)
}
180
View as plain text