1 package pkcs7
2
3 import (
4 "bytes"
5 "crypto"
6 "crypto/aes"
7 "crypto/cipher"
8 "crypto/des"
9 "crypto/rand"
10 "crypto/rsa"
11 "crypto/x509"
12 "encoding/asn1"
13 "errors"
14 "fmt"
15 )
16
17
// ErrUnsupportedAlgorithm is returned when the supplied private key is not an
// RSA key, or when the content encryption algorithm of the message is not one
// of the algorithms this package can decrypt.
var ErrUnsupportedAlgorithm = errors.New("pkcs7: cannot decrypt data: only RSA, DES, DES-EDE3, AES-128-CBC, AES-256-CBC, AES-128-GCM and AES-256-GCM supported")

// ErrNotEncryptedContent is returned when decryption is attempted on a PKCS7
// value whose content is not of the encrypted type the method expects.
var ErrNotEncryptedContent = errors.New("pkcs7: content data is not a decryptable data type")
22
23
24 func (p7 *PKCS7) Decrypt(cert *x509.Certificate, pkey crypto.PrivateKey) ([]byte, error) {
25 data, ok := p7.raw.(envelopedData)
26 if !ok {
27 return nil, ErrNotEncryptedContent
28 }
29 recipient := selectRecipientForCertificate(data.RecipientInfos, cert)
30 if recipient.EncryptedKey == nil {
31 return nil, errors.New("pkcs7: no enveloped recipient for provided certificate")
32 }
33 switch pkey := pkey.(type) {
34 case *rsa.PrivateKey:
35 var contentKey []byte
36 contentKey, err := rsa.DecryptPKCS1v15(rand.Reader, pkey, recipient.EncryptedKey)
37 if err != nil {
38 return nil, err
39 }
40 return data.EncryptedContentInfo.decrypt(contentKey)
41 }
42 return nil, ErrUnsupportedAlgorithm
43 }
44
45
46
47 func (p7 *PKCS7) DecryptUsingPSK(key []byte) ([]byte, error) {
48 data, ok := p7.raw.(encryptedData)
49 if !ok {
50 return nil, ErrNotEncryptedContent
51 }
52 return data.EncryptedContentInfo.decrypt(key)
53 }
54
55 func (eci encryptedContentInfo) decrypt(key []byte) ([]byte, error) {
56 alg := eci.ContentEncryptionAlgorithm.Algorithm
57 if !alg.Equal(OIDEncryptionAlgorithmDESCBC) &&
58 !alg.Equal(OIDEncryptionAlgorithmDESEDE3CBC) &&
59 !alg.Equal(OIDEncryptionAlgorithmAES256CBC) &&
60 !alg.Equal(OIDEncryptionAlgorithmAES128CBC) &&
61 !alg.Equal(OIDEncryptionAlgorithmAES128GCM) &&
62 !alg.Equal(OIDEncryptionAlgorithmAES256GCM) {
63 fmt.Printf("Unsupported Content Encryption Algorithm: %s\n", alg)
64 return nil, ErrUnsupportedAlgorithm
65 }
66
67
68
69 var cyphertext []byte
70 if eci.EncryptedContent.IsCompound {
71
72 var buf bytes.Buffer
73 cypherbytes := eci.EncryptedContent.Bytes
74 for {
75 var part []byte
76 cypherbytes, _ = asn1.Unmarshal(cypherbytes, &part)
77 buf.Write(part)
78 if cypherbytes == nil {
79 break
80 }
81 }
82 cyphertext = buf.Bytes()
83 } else {
84
85 cyphertext = eci.EncryptedContent.Bytes
86 }
87
88 var block cipher.Block
89 var err error
90
91 switch {
92 case alg.Equal(OIDEncryptionAlgorithmDESCBC):
93 block, err = des.NewCipher(key)
94 case alg.Equal(OIDEncryptionAlgorithmDESEDE3CBC):
95 block, err = des.NewTripleDESCipher(key)
96 case alg.Equal(OIDEncryptionAlgorithmAES256CBC), alg.Equal(OIDEncryptionAlgorithmAES256GCM):
97 fallthrough
98 case alg.Equal(OIDEncryptionAlgorithmAES128GCM), alg.Equal(OIDEncryptionAlgorithmAES128CBC):
99 block, err = aes.NewCipher(key)
100 }
101
102 if err != nil {
103 return nil, err
104 }
105
106 if alg.Equal(OIDEncryptionAlgorithmAES128GCM) || alg.Equal(OIDEncryptionAlgorithmAES256GCM) {
107 params := aesGCMParameters{}
108 paramBytes := eci.ContentEncryptionAlgorithm.Parameters.Bytes
109
110 _, err := asn1.Unmarshal(paramBytes, ¶ms)
111 if err != nil {
112 return nil, err
113 }
114
115 gcm, err := cipher.NewGCM(block)
116 if err != nil {
117 return nil, err
118 }
119
120 if len(params.Nonce) != gcm.NonceSize() {
121 return nil, errors.New("pkcs7: encryption algorithm parameters are incorrect")
122 }
123 if params.ICVLen != gcm.Overhead() {
124 return nil, errors.New("pkcs7: encryption algorithm parameters are incorrect")
125 }
126
127 plaintext, err := gcm.Open(nil, params.Nonce, cyphertext, nil)
128 if err != nil {
129 return nil, err
130 }
131
132 return plaintext, nil
133 }
134
135 iv := eci.ContentEncryptionAlgorithm.Parameters.Bytes
136 if len(iv) != block.BlockSize() {
137 return nil, errors.New("pkcs7: encryption algorithm parameters are malformed")
138 }
139 mode := cipher.NewCBCDecrypter(block, iv)
140 plaintext := make([]byte, len(cyphertext))
141 mode.CryptBlocks(plaintext, cyphertext)
142 if plaintext, err = unpad(plaintext, mode.BlockSize()); err != nil {
143 return nil, err
144 }
145 return plaintext, nil
146 }
147
// unpad removes the trailing padding from data, where every padding byte
// holds the padding length, and returns the remaining bytes.
//
// data must be non-empty and a whole number of blocklen-sized blocks. The
// padding length, taken from the last byte, must be between 1 and blocklen
// and every padding byte must equal it; otherwise an error is returned. The
// returned slice shares its backing array with data.
func unpad(data []byte, blocklen int) ([]byte, error) {
	if blocklen < 1 {
		return nil, fmt.Errorf("invalid blocklen %d", blocklen)
	}
	if len(data)%blocklen != 0 || len(data) == 0 {
		return nil, fmt.Errorf("invalid data len %d", len(data))
	}

	// The last byte is the length of the padding.
	padlen := int(data[len(data)-1])

	// Reject a zero or oversized pad length before slicing. Since
	// len(data) >= blocklen, this bound also keeps the slice expressions
	// below in range for arbitrary (untrusted) input.
	if padlen < 1 || padlen > blocklen {
		return nil, errors.New("invalid padding")
	}

	// Verify that every padding byte carries the padding length.
	body, pad := data[:len(data)-padlen], data[len(data)-padlen:]
	for _, padbyte := range pad {
		if int(padbyte) != padlen {
			return nil, errors.New("invalid padding")
		}
	}

	return body, nil
}
169
170 func selectRecipientForCertificate(recipients []recipientInfo, cert *x509.Certificate) recipientInfo {
171 for _, recp := range recipients {
172 if isCertMatchForIssuerAndSerial(cert, recp.IssuerAndSerialNumber) {
173 return recp
174 }
175 }
176 return recipientInfo{}
177 }
178
View as plain text