...

Source file src/github.com/digitorus/pkcs7/decrypt.go

Documentation: github.com/digitorus/pkcs7

     1  package pkcs7
     2  
     3  import (
     4  	"bytes"
     5  	"crypto"
     6  	"crypto/aes"
     7  	"crypto/cipher"
     8  	"crypto/des"
     9  	"crypto/rand"
    10  	"crypto/rsa"
    11  	"crypto/x509"
    12  	"encoding/asn1"
    13  	"errors"
    14  	"fmt"
    15  )
    16  
    17  // ErrUnsupportedAlgorithm tells you when our quick dev assumptions have failed
    18  var ErrUnsupportedAlgorithm = errors.New("pkcs7: cannot decrypt data: only RSA, DES, DES-EDE3, AES-256-CBC and AES-128-GCM supported")
    19  
    20  // ErrNotEncryptedContent is returned when attempting to Decrypt data that is not encrypted data
    21  var ErrNotEncryptedContent = errors.New("pkcs7: content data is a decryptable data type")
    22  
    23  // Decrypt decrypts encrypted content info for recipient cert and private key
    24  func (p7 *PKCS7) Decrypt(cert *x509.Certificate, pkey crypto.PrivateKey) ([]byte, error) {
    25  	data, ok := p7.raw.(envelopedData)
    26  	if !ok {
    27  		return nil, ErrNotEncryptedContent
    28  	}
    29  	recipient := selectRecipientForCertificate(data.RecipientInfos, cert)
    30  	if recipient.EncryptedKey == nil {
    31  		return nil, errors.New("pkcs7: no enveloped recipient for provided certificate")
    32  	}
    33  	switch pkey := pkey.(type) {
    34  	case *rsa.PrivateKey:
    35  		var contentKey []byte
    36  		contentKey, err := rsa.DecryptPKCS1v15(rand.Reader, pkey, recipient.EncryptedKey)
    37  		if err != nil {
    38  			return nil, err
    39  		}
    40  		return data.EncryptedContentInfo.decrypt(contentKey)
    41  	}
    42  	return nil, ErrUnsupportedAlgorithm
    43  }
    44  
    45  // DecryptUsingPSK decrypts encrypted data using caller provided
    46  // pre-shared secret
    47  func (p7 *PKCS7) DecryptUsingPSK(key []byte) ([]byte, error) {
    48  	data, ok := p7.raw.(encryptedData)
    49  	if !ok {
    50  		return nil, ErrNotEncryptedContent
    51  	}
    52  	return data.EncryptedContentInfo.decrypt(key)
    53  }
    54  
    55  func (eci encryptedContentInfo) decrypt(key []byte) ([]byte, error) {
    56  	alg := eci.ContentEncryptionAlgorithm.Algorithm
    57  	if !alg.Equal(OIDEncryptionAlgorithmDESCBC) &&
    58  		!alg.Equal(OIDEncryptionAlgorithmDESEDE3CBC) &&
    59  		!alg.Equal(OIDEncryptionAlgorithmAES256CBC) &&
    60  		!alg.Equal(OIDEncryptionAlgorithmAES128CBC) &&
    61  		!alg.Equal(OIDEncryptionAlgorithmAES128GCM) &&
    62  		!alg.Equal(OIDEncryptionAlgorithmAES256GCM) {
    63  		fmt.Printf("Unsupported Content Encryption Algorithm: %s\n", alg)
    64  		return nil, ErrUnsupportedAlgorithm
    65  	}
    66  
    67  	// EncryptedContent can either be constructed of multple OCTET STRINGs
    68  	// or _be_ a tagged OCTET STRING
    69  	var cyphertext []byte
    70  	if eci.EncryptedContent.IsCompound {
    71  		// Complex case to concat all of the children OCTET STRINGs
    72  		var buf bytes.Buffer
    73  		cypherbytes := eci.EncryptedContent.Bytes
    74  		for {
    75  			var part []byte
    76  			cypherbytes, _ = asn1.Unmarshal(cypherbytes, &part)
    77  			buf.Write(part)
    78  			if cypherbytes == nil {
    79  				break
    80  			}
    81  		}
    82  		cyphertext = buf.Bytes()
    83  	} else {
    84  		// Simple case, the bytes _are_ the cyphertext
    85  		cyphertext = eci.EncryptedContent.Bytes
    86  	}
    87  
    88  	var block cipher.Block
    89  	var err error
    90  
    91  	switch {
    92  	case alg.Equal(OIDEncryptionAlgorithmDESCBC):
    93  		block, err = des.NewCipher(key)
    94  	case alg.Equal(OIDEncryptionAlgorithmDESEDE3CBC):
    95  		block, err = des.NewTripleDESCipher(key)
    96  	case alg.Equal(OIDEncryptionAlgorithmAES256CBC), alg.Equal(OIDEncryptionAlgorithmAES256GCM):
    97  		fallthrough
    98  	case alg.Equal(OIDEncryptionAlgorithmAES128GCM), alg.Equal(OIDEncryptionAlgorithmAES128CBC):
    99  		block, err = aes.NewCipher(key)
   100  	}
   101  
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	if alg.Equal(OIDEncryptionAlgorithmAES128GCM) || alg.Equal(OIDEncryptionAlgorithmAES256GCM) {
   107  		params := aesGCMParameters{}
   108  		paramBytes := eci.ContentEncryptionAlgorithm.Parameters.Bytes
   109  
   110  		_, err := asn1.Unmarshal(paramBytes, &params)
   111  		if err != nil {
   112  			return nil, err
   113  		}
   114  
   115  		gcm, err := cipher.NewGCM(block)
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  
   120  		if len(params.Nonce) != gcm.NonceSize() {
   121  			return nil, errors.New("pkcs7: encryption algorithm parameters are incorrect")
   122  		}
   123  		if params.ICVLen != gcm.Overhead() {
   124  			return nil, errors.New("pkcs7: encryption algorithm parameters are incorrect")
   125  		}
   126  
   127  		plaintext, err := gcm.Open(nil, params.Nonce, cyphertext, nil)
   128  		if err != nil {
   129  			return nil, err
   130  		}
   131  
   132  		return plaintext, nil
   133  	}
   134  
   135  	iv := eci.ContentEncryptionAlgorithm.Parameters.Bytes
   136  	if len(iv) != block.BlockSize() {
   137  		return nil, errors.New("pkcs7: encryption algorithm parameters are malformed")
   138  	}
   139  	mode := cipher.NewCBCDecrypter(block, iv)
   140  	plaintext := make([]byte, len(cyphertext))
   141  	mode.CryptBlocks(plaintext, cyphertext)
   142  	if plaintext, err = unpad(plaintext, mode.BlockSize()); err != nil {
   143  		return nil, err
   144  	}
   145  	return plaintext, nil
   146  }
   147  
   148  func unpad(data []byte, blocklen int) ([]byte, error) {
   149  	if blocklen < 1 {
   150  		return nil, fmt.Errorf("invalid blocklen %d", blocklen)
   151  	}
   152  	if len(data)%blocklen != 0 || len(data) == 0 {
   153  		return nil, fmt.Errorf("invalid data len %d", len(data))
   154  	}
   155  
   156  	// the last byte is the length of padding
   157  	padlen := int(data[len(data)-1])
   158  
   159  	// check padding integrity, all bytes should be the same
   160  	pad := data[len(data)-padlen:]
   161  	for _, padbyte := range pad {
   162  		if padbyte != byte(padlen) {
   163  			return nil, errors.New("invalid padding")
   164  		}
   165  	}
   166  
   167  	return data[:len(data)-padlen], nil
   168  }
   169  
   170  func selectRecipientForCertificate(recipients []recipientInfo, cert *x509.Certificate) recipientInfo {
   171  	for _, recp := range recipients {
   172  		if isCertMatchForIssuerAndSerial(cert, recp.IssuerAndSerialNumber) {
   173  			return recp
   174  		}
   175  	}
   176  	return recipientInfo{}
   177  }
   178  

View as plain text